Initial commit
Browse files- README.md +4 -4
- vianu/__init__.py +71 -0
- vianu/spock/Dockerfile +24 -0
- vianu/spock/__init__.py +0 -0
- vianu/spock/__main__.py +158 -0
- vianu/spock/app/__init__.py +3 -0
- vianu/spock/app/app.py +592 -0
- vianu/spock/app/formatter.py +98 -0
- vianu/spock/assets/css/styles.css +191 -0
- vianu/spock/assets/head/scripts.html +11 -0
- vianu/spock/assets/images/favicon.png +0 -0
- vianu/spock/assets/images/spock_logo.png +0 -0
- vianu/spock/assets/images/spock_logo_circular.png +0 -0
- vianu/spock/launch_demo_app.py +18 -0
- vianu/spock/launch_demo_pipeline.py +21 -0
- vianu/spock/requirements.txt +9 -0
- vianu/spock/settings.py +32 -0
- vianu/spock/src/__init__.py +0 -0
- vianu/spock/src/base.py +277 -0
- vianu/spock/src/cli.py +33 -0
- vianu/spock/src/ner.py +221 -0
- vianu/spock/src/scraping.py +922 -0
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
short_description: Spotting clinical knowledge
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SpoCK
|
| 3 |
+
emoji: 🖖
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
short_description: Spotting clinical knowledge
|
vianu/__init__.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, List
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from gradio.events import Dependency
|
| 8 |
+
|
| 9 |
+
LOG_FMT = "%(asctime)s | %(name)s | %(funcName)s | %(levelname)s | %(message)s"
|
| 10 |
+
|
| 11 |
+
class BaseApp(ABC):
    """The abstract base class of the main gradio application.

    Subclasses implement :meth:`setup_ui` and :meth:`register_events`;
    callers then obtain the assembled :class:`gradio.Blocks` demo via
    :meth:`make`.
    """

    def __init__(
        self,
        app_name: str | None = None,
        favicon_path: Path | None = None,
        allowed_paths: List[str] | None = None,
        head_file: Path | None = None,
        css_file: Path | None = None,
        theme: gr.Theme | None = None,
        local_state: Any | None = None,
        session_state: Any | None = None,
    ):
        """
        Args:
            app_name: The name of the application. Defaults to None.
            favicon_path: The favicon file as a :class:`pathlib.Path`. Defaults to None.
            allowed_paths: Additional filesystem paths gradio is allowed to serve. Defaults to None.
            head_file: Custom html code as a :class:`pathlib.Path` to a html file. Defaults to None.
            css_file: Custom css as a :class:`pathlib.Path` to a css file. Defaults to None.
            theme: The theme of the application. Defaults to None.
            local_state: The local state, where data persists in the browser's localStorage even after the page is refreshed or closed. Should be a json-serializable value (accessible only through it's serialized form). Defaults to None.
            session_state: The session state, where data persists across multiple submits within a page session. Defaults to None.
        """
        # Public attributes: read by the launch code (e.g. passed to `demo.launch(...)`).
        self.favicon_path = favicon_path
        self.allowed_paths = allowed_paths

        self._app_name = app_name
        self._head_file = head_file
        self._css_file = css_file
        self._theme = theme

        # Wrap the raw state values in gradio state components; both are
        # rendered explicitly inside the Blocks context in `make()`.
        self._local_state = gr.BrowserState(local_state)
        self._session_state = gr.State(session_state)

    @abstractmethod
    def setup_ui(self):
        """Set up the user interface."""
        pass

    @abstractmethod
    def register_events(self):
        """Register the events."""
        pass

    def make(self) -> Dependency:
        """Assemble and return the gradio Blocks demo.

        Renders the state components, builds the UI (`setup_ui`) and wires
        the event handlers (`register_events`) inside a single
        :class:`gradio.Blocks` context.
        """
        with gr.Blocks(
            title=self._app_name,
            head_paths=self._head_file,
            css_paths=self._css_file,
            theme=self._theme,
        ) as demo:
            self._local_state.render()
            self._session_state.render()
            self.setup_ui()
            self.register_events()

            # Register a bare page-load event; presumably needed so an initial
            # render/update cycle is triggered — confirm against gradio docs.
            demo.load()
        return demo
vianu/spock/Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Slim Python 3.11 base image keeps the container small while matching the runtime
FROM python:3.11-slim

# Set environment variables to prevent Python from writing .pyc files and to buffer stdout and stderr
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# Set the working directory inside the container
WORKDIR /app
ENV PYTHONPATH=/app
# Bind gradio to all interfaces so the app is reachable from outside the container
ENV GRADIO_SERVER_NAME="0.0.0.0"

# Copy the files into the working directory
# NOTE(review): COPY paths are resolved against the build context, and `..`
# cannot escape it — the ../../ prefix presumably assumes a specific
# `docker build` invocation/context; confirm how the image is built.
COPY ../../vianu/__init__.py /app/vianu/__init__.py
COPY ../../vianu/spock /app/vianu/spock

# Install dependencies from requirements.txt
RUN pip install --upgrade pip \
    && pip install -r vianu/spock/requirements.txt

# Expose the port your Gradio app will run on
# NOTE(review): 7868 differs from gradio's default 7860 — confirm it matches
# GRADIO_SERVER_PORT in vianu/spock/settings.py.
EXPOSE 7868

# Command to run the application
CMD ["python", "vianu/spock/launch_demo_app.py"]
vianu/spock/__init__.py
ADDED
|
File without changes
|
vianu/spock/__main__.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import Namespace
|
| 2 |
+
import asyncio
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import logging
|
| 5 |
+
import sys
|
| 6 |
+
from typing import List, Tuple
|
| 7 |
+
|
| 8 |
+
from vianu import LOG_FMT
|
| 9 |
+
from vianu.spock.settings import SCRAPING_SOURCES, LOG_LEVEL
|
| 10 |
+
from vianu.spock.src.cli import parse_args
|
| 11 |
+
from vianu.spock.src.base import Setup, Document, SpoCK, FileHandler
|
| 12 |
+
from vianu.spock.src import scraping as scp
|
| 13 |
+
from vianu.spock.src import ner
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logging.basicConfig(format=LOG_FMT, level=LOG_LEVEL)
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
async def _orchestrator(
    args_: Namespace,
    src_queue: asyncio.Queue,
    scp_queue: asyncio.Queue,
    ner_queue: asyncio.Queue,
    scp_tasks: List[asyncio.Task],
    ner_tasks: List[asyncio.Task],
) -> None:
    """Orchestrates the scraping and NER tasks.

    It waits for all scraping tasks to finish, then sends a sentinel to the scp_queue for each ner task (which will
    trigger the ner tasks to finish -> cf :func:`vianu.spock.src.ner.apply`).

    Args:
        args_: Parsed pipeline arguments; only ``args_.source`` is read here.
        src_queue: Queue feeding source names to the scraping tasks.
        scp_queue: Hand-off queue between scraping and NER tasks.
        ner_queue: Queue carrying final results; a single ``None`` sentinel is
            appended once everything is done (consumed by the collector).
        scp_tasks: The running scraping tasks.
        ner_tasks: The running NER tasks.
    """
    logger.debug('setting up orchestrator task')

    # Insert sources into the source queue
    sources = args_.source
    for src in sources:
        await src_queue.put(src)

    # Insert sentinel for each scraping task
    # (each scraper exits after consuming exactly one None)
    for _ in range(len(scp_tasks)):
        await src_queue.put(None)

    # Wait for all scraper tasks to finish and stop them
    await src_queue.join()
    try:
        await asyncio.gather(*scp_tasks)
    except asyncio.CancelledError:
        logger.warning('scraping task(s) have previously been canceled')
    except Exception as e:
        logger.error(f'scraping task(s) failed with error: {e}')
        raise e
    # cancel() is a no-op for tasks that already finished; it only stops
    # stragglers that are somehow still alive at this point.
    for st in scp_tasks:
        st.cancel()

    # Insert sentinel for each NER task
    for _ in range(len(ner_tasks)):
        await scp_queue.put(None)

    # Wait for NER tasks to process all items and finish
    await scp_queue.join()
    try:
        await asyncio.gather(*ner_tasks)
    except asyncio.CancelledError:
        logger.warning('ner task(s) have previously been canceled')
    except Exception as e:
        logger.error(f'ner task(s) failed with error: {e}')
        raise e
    for nt in ner_tasks:
        nt.cancel()

    # Insert sentinel into ner_queue to indicate end of processing
    await ner_queue.put(None)
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def setup_asyncio_framework(args_: Namespace) -> Tuple[asyncio.Queue, List[asyncio.Task], List[asyncio.Task], asyncio.Task]:
    """Wire up the queues and tasks that drive one SpoCK run.

    Falls back to all configured scraping sources when none were requested,
    creates the three hand-off queues (sources -> scraping -> NER), starts the
    worker tasks, and launches the orchestrator that feeds and terminates them.

    Returns:
        The NER output queue, the scraping tasks, the NER tasks, and the
        orchestrator task.
    """
    # Default to every configured source when the caller did not restrict them.
    if args_.source is None:
        args_.source = SCRAPING_SOURCES

    # One queue per pipeline stage: src_queue feeds the scrapers, scp_queue
    # feeds the NER workers, ner_queue carries the final results.
    src_queue: asyncio.Queue = asyncio.Queue()
    scp_queue: asyncio.Queue = asyncio.Queue()
    ner_queue: asyncio.Queue = asyncio.Queue()

    # Spawn the stage workers, then the orchestrator that shepherds them.
    scp_tasks = scp.create_tasks(args_=args_, queue_in=src_queue, queue_out=scp_queue)
    ner_tasks = ner.create_tasks(args_=args_, queue_in=scp_queue, queue_out=ner_queue)
    orchestration = _orchestrator(
        args_=args_,
        src_queue=src_queue,
        scp_queue=scp_queue,
        ner_queue=ner_queue,
        scp_tasks=scp_tasks,
        ner_tasks=ner_tasks,
    )
    orc_task = asyncio.create_task(orchestration)

    return ner_queue, scp_tasks, ner_tasks, orc_task
| 101 |
+
|
| 102 |
+
|
| 103 |
+
async def _collector(ner_queue: asyncio.Queue) -> List[Document]:
    """Drain the NER queue into a list of documents.

    Consumes items until the ``None`` sentinel (inserted by the orchestrator)
    is seen. Every retrieved item — including the sentinel — is acknowledged
    with ``task_done()`` so that ``ner_queue.join()`` can unblock.
    """
    collected = []
    # Walrus loop: exit as soon as the sentinel shows up.
    while (item := await ner_queue.get()) is not None:
        collected.append(item.doc)
        ner_queue.task_done()
    # Acknowledge the sentinel itself before returning.
    ner_queue.task_done()
    return collected
| 119 |
+
|
| 120 |
+
|
| 121 |
+
async def main(args_: Namespace | None = None, save: bool = True) -> None:
    """Main function for the SpoCK pipeline.

    Args:
        args_: Parsed CLI arguments; if None they are read from ``sys.argv``.
        save: If True, write the assembled :class:`SpoCK` result to disk
            (requires both ``args_.file_name`` and ``args_.file_path``).
    """
    started_at = datetime.now()
    if args_ is None:
        args_ = parse_args(sys.argv[1:])

    # The module-level basicConfig() call already installed handlers on the
    # root logger, which makes a second basicConfig() a silent no-op — the
    # CLI-supplied log level would be ignored. force=True reconfigures the
    # root logger so the requested level actually takes effect.
    logging.basicConfig(level=args_.log_level.upper(), format=LOG_FMT, force=True)
    logger.info(f'starting SpoCK (args_={args_})')

    # Set up async structure (scraping queue/tasks, NER queue/tasks, orchestrator task)
    ner_queue, _, _, _ = setup_asyncio_framework(args_)

    # Set up collector task and wait for it to finish
    # NOTE: if collector task is finished, the orchestrator is also finished (because of the sentinel in `ner_queue`)
    # and therefore so are the scraping and NER tasks
    col_task = asyncio.create_task(_collector(ner_queue))
    data = await col_task
    await ner_queue.join()

    # Save data (only assemble the result object when it will actually be written)
    if save:
        file_name = args_.file_name
        file_path = args_.file_path
        if file_name is not None and file_path is not None:
            spock = SpoCK(
                id_=str(args_),
                status='completed',
                started_at=started_at,
                finished_at=datetime.now(),
                setup=Setup.from_namespace(args_),
                data=data,
            )
            FileHandler(file_path=file_path).write(file_name=file_name, spock=spock)
    logger.info('finished SpoCK')
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if __name__ == '__main__':
|
| 158 |
+
asyncio.run(main())
|
vianu/spock/app/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from vianu.spock.app.app import App
|
| 2 |
+
|
| 3 |
+
__all__ = ['App']
|
vianu/spock/app/app.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, List, Tuple
|
| 8 |
+
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
import gradio as gr
|
| 11 |
+
|
| 12 |
+
from vianu.spock.settings import LOG_LEVEL, N_SCP_TASKS, N_NER_TASKS
|
| 13 |
+
from vianu.spock.settings import LARGE_LANGUAGE_MODELS, SCRAPING_SOURCES, MAX_DOCS_SRC
|
| 14 |
+
from vianu.spock.settings import GRADIO_APP_NAME, GRADIO_SERVER_PORT, GRADIO_MAX_JOBS, GRADIO_UPDATE_INTERVAL
|
| 15 |
+
from vianu.spock.settings import OLLAMA_BASE_URL_ENV_NAME, OPENAI_API_KEY_ENV_NAME
|
| 16 |
+
from vianu.spock.src.base import Setup, SpoCK, SpoCKList, QueueItem # noqa: F401
|
| 17 |
+
from vianu import BaseApp
|
| 18 |
+
from vianu.spock.__main__ import setup_asyncio_framework
|
| 19 |
+
import vianu.spock.app.formatter as fmt
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
load_dotenv()
|
| 23 |
+
|
| 24 |
+
# App settings
|
| 25 |
+
# Root of the static assets shipped with the app (css, images, html head snippets).
_ASSETS_PATH = Path(__file__).parents[1] / "assets"

# (display label, settings value) pairs for the UI radio/checkbox groups.
# `strict=True` makes zip raise ValueError whenever the label lists get out of
# sync with the settings lists. The previous `len(choices) == len(settings)`
# checks could not detect a settings list *shorter* than its label list,
# because plain zip silently truncates to the shorter input.
_UI_SETTINGS_LLM_CHOICES = list(zip(['Ollama', 'OpenAI'], LARGE_LANGUAGE_MODELS, strict=True))
_UI_SETTINGS_SOURCE_CHOICES = list(zip(['PubMed', 'EMA', 'MHRA', 'FDA'], SCRAPING_SOURCES, strict=True))
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class LocalState:
    """The persistent local state.

    Stored via :class:`gr.BrowserState` in the browser's localStorage, so the
    values survive page reloads. All fields default to the module settings.
    """
    log_level: str = LOG_LEVEL  # logging verbosity used for pipeline runs
    n_scp_tasks: int = N_SCP_TASKS  # number of concurrent scraping tasks
    n_ner_tasks: int = N_NER_TASKS  # number of concurrent NER tasks
    max_jobs: int = GRADIO_MAX_JOBS  # max number of job cards kept in the UI
    update_interval: float = GRADIO_UPDATE_INTERVAL  # gr.Timer refresh period (seconds)
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
class SessionState:
    """The session dependent state.

    Holds the asyncio plumbing of the currently running pipeline plus the
    per-session job history (`spocks`; new jobs are inserted at index 0).
    """

    # Asyncio setup
    ner_queue: asyncio.Queue | None = None  # output queue of the NER workers
    scp_tasks: List[asyncio.Task] = field(default_factory=list)  # scraping worker tasks
    ner_tasks: List[asyncio.Task] = field(default_factory=list)  # NER worker tasks
    orc_task: asyncio.Task | None = None  # orchestrator task
    col_task: asyncio.Task | None = None  # result-collector task

    # Data
    is_running: bool = False  # True while a pipeline run is in progress
    spocks: SpoCKList = field(default_factory=list)  # job history (newest first)
    _index_running_spock: int | None = None  # index into `spocks` of the running job
    _index_active_spock: int | None = None  # index into `spocks` of the job shown in the UI

    def set_running_spock(self, index: int | None) -> None:
        """Remember which entry of `spocks` belongs to the running pipeline (None clears it)."""
        self._index_running_spock = index

    def get_running_spock(self) -> SpoCK:
        """Return the currently running job.

        NOTE(review): raises TypeError if no running index has been set —
        confirm callers always call `set_running_spock` first.
        """
        return self.spocks[self._index_running_spock]

    def set_active_spock(self, index: int | None) -> None:
        """Remember which entry of `spocks` is selected/displayed in the UI (None clears it)."""
        self._index_active_spock = index

    def get_active_spock(self) -> SpoCK:
        """Return the job selected in the UI (see the note on `get_running_spock`)."""
        return self.spocks[self._index_active_spock]
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class App(BaseApp):
|
| 76 |
+
"""The main gradio application."""
|
| 77 |
+
|
| 78 |
+
def __init__(self):
|
| 79 |
+
super().__init__(
|
| 80 |
+
app_name=GRADIO_APP_NAME,
|
| 81 |
+
favicon_path=_ASSETS_PATH / "images" / "favicon.png",
|
| 82 |
+
allowed_paths=[str(_ASSETS_PATH.resolve())],
|
| 83 |
+
head_file=_ASSETS_PATH / "head" / "scripts.html",
|
| 84 |
+
css_file=_ASSETS_PATH / "css" / "styles.css",
|
| 85 |
+
theme=gr.themes.Soft(),
|
| 86 |
+
local_state=LocalState(),
|
| 87 |
+
session_state=SessionState(),
|
| 88 |
+
)
|
| 89 |
+
self._components: Dict[str, Any] = {}
|
| 90 |
+
|
| 91 |
+
# --------------------------------------------------------------------------
|
| 92 |
+
# User Interface
|
| 93 |
+
# --------------------------------------------------------------------------
|
| 94 |
+
@staticmethod
|
| 95 |
+
def _ui_top_row():
|
| 96 |
+
with gr.Row(elem_classes="top-row"):
|
| 97 |
+
with gr.Column(scale=1):
|
| 98 |
+
gr.Image(
|
| 99 |
+
value=_ASSETS_PATH / "images" / "spock_logo_circular.png",
|
| 100 |
+
show_label=False,
|
| 101 |
+
elem_classes="image",
|
| 102 |
+
)
|
| 103 |
+
with gr.Column(scale=5):
|
| 104 |
+
value = """<div class='top-row title-desc'>
|
| 105 |
+
<div class='top-row title-desc title'>SpoCK: Spotting Clinical Knowledge</div>
|
| 106 |
+
<div class='top-row title-desc desc'><em>A tool for identifying <b>medicinal products</b> and <b>adverse drug reactions</b> inside publicly available literature</em></div>
|
| 107 |
+
</div>
|
| 108 |
+
"""
|
| 109 |
+
gr.Markdown(value=value)
|
| 110 |
+
|
| 111 |
+
def _ui_corpus_settings(self):
|
| 112 |
+
"""Settings column."""
|
| 113 |
+
with gr.Column(scale=1):
|
| 114 |
+
with gr.Accordion(label='LLM Endpoint'):
|
| 115 |
+
self._components['settings.llm_radio'] = gr.Radio(
|
| 116 |
+
label='Model', show_label=False, choices=_UI_SETTINGS_LLM_CHOICES, value='llama', interactive=True
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# 'llama' specific settings
|
| 120 |
+
with gr.Group(visible=True) as self._components['settings.ollama_group']:
|
| 121 |
+
value = os.environ.get("OLLAMA_BASE_URL")
|
| 122 |
+
placeholder = 'base_url of ollama endpoint' if value is None else None
|
| 123 |
+
gr.Markdown('---')
|
| 124 |
+
self._components['settings.ollama_base_url'] = gr.Textbox(
|
| 125 |
+
label='base_url',
|
| 126 |
+
show_label=False,
|
| 127 |
+
info='base_url',
|
| 128 |
+
placeholder=placeholder,
|
| 129 |
+
value=value,
|
| 130 |
+
interactive=True,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# 'openai' specific settings
|
| 134 |
+
with gr.Group(visible=False) as self._components['settings.openai_group']:
|
| 135 |
+
value = os.environ.get("OPENAI_API_KEY")
|
| 136 |
+
placeholder = 'api_key of openai endpoint' if value is None else None
|
| 137 |
+
logger.debug(f'openai api_key={value}')
|
| 138 |
+
gr.Markdown('---')
|
| 139 |
+
self._components['settings.openai_api_key'] = gr.Textbox(
|
| 140 |
+
label='api_key',
|
| 141 |
+
show_label=False,
|
| 142 |
+
info='api_key',
|
| 143 |
+
placeholder=placeholder,
|
| 144 |
+
value=value,
|
| 145 |
+
interactive=True,
|
| 146 |
+
type='password',
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
with gr.Accordion(label='Sources', open=True):
|
| 150 |
+
self._components['settings.source'] = gr.CheckboxGroup(
|
| 151 |
+
label='Sources', show_label=False, choices=_UI_SETTINGS_SOURCE_CHOICES, value=SCRAPING_SOURCES, interactive=True
|
| 152 |
+
)
|
| 153 |
+
self._components['settings.max_docs_src'] = gr.Number(
|
| 154 |
+
label='max_docs_src', show_label=False, info='max. number of documents per source', value=MAX_DOCS_SRC, interactive=True
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
def _ui_corpus_row(self):
|
| 158 |
+
"""Main corpus with settings, search field, job cards, and details"""
|
| 159 |
+
with gr.Row(elem_classes="bottom-container"):
|
| 160 |
+
self._ui_corpus_settings()
|
| 161 |
+
self._ui_corpus_main()
|
| 162 |
+
|
| 163 |
+
def _ui_corpus_main(self):
|
| 164 |
+
"""Search field, job cards, and details."""
|
| 165 |
+
with gr.Column(scale=5):
|
| 166 |
+
# Search text field and start/stop/cancel buttons
|
| 167 |
+
with gr.Row(elem_classes="search-container"):
|
| 168 |
+
with gr.Column(scale=3):
|
| 169 |
+
self._components['main.search_term'] = gr.Textbox(
|
| 170 |
+
label="Search", show_label=False, placeholder="Enter your search term"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
with gr.Column(scale=1, elem_classes='pipeline-button'):
|
| 174 |
+
self._components['main.start_button'] = gr.HTML('<div class="button-not-running">Start</div>', visible=True)
|
| 175 |
+
self._components['main.stop_button'] = gr.HTML('<div class="button-running">Stop</div>', visible=False)
|
| 176 |
+
self._components['main.cancel_button'] = gr.HTML('<div class="canceling">canceling...</div>', visible=False)
|
| 177 |
+
|
| 178 |
+
# Job summary cards
|
| 179 |
+
with gr.Row(elem_classes="jobs-container"):
|
| 180 |
+
self._components['main.cards'] = [gr.HTML('', elem_id=f'job-{i}', visible=False) for i in range(GRADIO_MAX_JOBS)]
|
| 181 |
+
|
| 182 |
+
# Details of the selected job
|
| 183 |
+
with gr.Row():
|
| 184 |
+
self._components['main.details'] = gr.HTML('<div class="details-container"></div>')
|
| 185 |
+
|
| 186 |
+
def setup_ui(self):
|
| 187 |
+
"""Set up the user interface."""
|
| 188 |
+
self._ui_top_row()
|
| 189 |
+
self._ui_corpus_row()
|
| 190 |
+
self._components['timer'] = gr.Timer(value=GRADIO_UPDATE_INTERVAL, active=False, render=True)
|
| 191 |
+
|
| 192 |
+
# --------------------------------------------------------------------------
|
| 193 |
+
# Helpers
|
| 194 |
+
# --------------------------------------------------------------------------
|
| 195 |
+
@staticmethod
|
| 196 |
+
def _show_llm_settings(llm: str) -> Tuple[dict[str, Any], dict[str, Any]]:
|
| 197 |
+
logger.debug(f'show {llm} model settings')
|
| 198 |
+
if llm == 'llama':
|
| 199 |
+
return gr.update(visible=True), gr.update(visible=False)
|
| 200 |
+
elif llm == 'openai':
|
| 201 |
+
return gr.update(visible=False), gr.update(visible=True)
|
| 202 |
+
else:
|
| 203 |
+
return gr.update(visible=False), gr.update(visible=False)
|
| 204 |
+
|
| 205 |
+
@staticmethod
|
| 206 |
+
def _set_ollama_base_url(base_url: str) -> None:
|
| 207 |
+
"""Setup ollama base_url as environment variable."""
|
| 208 |
+
logger.debug(f'set ollama base_url environment variable ({OLLAMA_BASE_URL_ENV_NAME}={base_url})')
|
| 209 |
+
os.environ[OLLAMA_BASE_URL_ENV_NAME] = base_url
|
| 210 |
+
|
| 211 |
+
@staticmethod
|
| 212 |
+
def _set_openai_api_key(api_key: str) -> None:
|
| 213 |
+
"""Setup openai api_key as environment variable."""
|
| 214 |
+
log_key = '*****' if api_key else 'None'
|
| 215 |
+
logger.debug(f'set openai api key (api_key={log_key})')
|
| 216 |
+
os.environ[OPENAI_API_KEY_ENV_NAME] = api_key
|
| 217 |
+
|
| 218 |
+
@staticmethod
|
| 219 |
+
def _feed_cards_to_ui(local_state: dict, session_state: SessionState) -> List[dict[str, Any]]:
|
| 220 |
+
"""For all existing SpoCKs, create and feed the job cards to the UI."""
|
| 221 |
+
spocks = session_state.spocks
|
| 222 |
+
logger.debug(f'feeding cards to UI (len(spocks)={len(spocks)})')
|
| 223 |
+
|
| 224 |
+
# Create the job cards for the existing spocks
|
| 225 |
+
cds = []
|
| 226 |
+
for i, spk in enumerate(spocks):
|
| 227 |
+
html = fmt.get_job_card_html(i, spk)
|
| 228 |
+
cds.append(gr.update(value=html, visible=True))
|
| 229 |
+
|
| 230 |
+
# Extdend with not-visible cards
|
| 231 |
+
# Note: for gradio >= 5.0.0 this logic could be replaces with dynamic number of gr.Blocks
|
| 232 |
+
# (see https://www.gradio.app/guides/dynamic-apps-with-render-decorator)
|
| 233 |
+
cds.extend([gr.update(visible=False) for _ in range(local_state['max_jobs'] - len(spocks))])
|
| 234 |
+
return cds
|
| 235 |
+
|
| 236 |
+
@staticmethod
|
| 237 |
+
def _feed_details_to_ui(session_state: SessionState) -> str:
|
| 238 |
+
"""Collect the html texts for the documents of the selected job and feed them to the UI."""
|
| 239 |
+
if len(session_state.spocks) == 0:
|
| 240 |
+
return fmt.get_details_html([])
|
| 241 |
+
|
| 242 |
+
active_spock = session_state.get_active_spock()
|
| 243 |
+
logger.debug(f'feeding details to UI (len(data)={len(active_spock.data)})')
|
| 244 |
+
return fmt.get_details_html(active_spock.data)
|
| 245 |
+
|
| 246 |
+
@staticmethod
|
| 247 |
+
def _check_llm_settings(llm: str) -> None:
|
| 248 |
+
"""Check if the LLM settings are set."""
|
| 249 |
+
if llm == 'llama':
|
| 250 |
+
if os.environ.get(OLLAMA_BASE_URL_ENV_NAME) is None:
|
| 251 |
+
raise gr.Error('Ollama base_url is not set (submit value with Enter)')
|
| 252 |
+
elif llm == 'openai':
|
| 253 |
+
if os.environ.get(OPENAI_API_KEY_ENV_NAME) is None:
|
| 254 |
+
raise gr.Error('OpenAI api_key is not set (submit value with Enter)')
|
| 255 |
+
|
| 256 |
+
@staticmethod
|
| 257 |
+
def _toggle_button(session_state: SessionState) -> Tuple[SessionState, dict[str, Any], dict[str, Any], dict[str, Any]]:
|
| 258 |
+
"""Toggle the state of the pipleline between running <-> not_running.
|
| 259 |
+
|
| 260 |
+
As a result the corresponding buttons (Start, Stop, canceling...) are shown/hidden.
|
| 261 |
+
"""
|
| 262 |
+
logger.debug(f'toggle button (is_running={session_state.is_running}->{not session_state.is_running})')
|
| 263 |
+
session_state.is_running = not session_state.is_running
|
| 264 |
+
if session_state.is_running:
|
| 265 |
+
# Show the stop button and hide the start/cancel button
|
| 266 |
+
return session_state, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
|
| 267 |
+
else:
|
| 268 |
+
# Show the start button and hide the stop/cancel button
|
| 269 |
+
return session_state, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
|
| 270 |
+
|
| 271 |
+
@staticmethod
|
| 272 |
+
def _show_cancel_button() -> Tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
| 273 |
+
"""Shows the cancel button and hides the start and stop button."""
|
| 274 |
+
logger.debug('show cancel button')
|
| 275 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
|
| 276 |
+
|
| 277 |
+
@staticmethod
|
| 278 |
+
def _setup_spock(term: str, model: str, source: List[str], max_docs_src: int, local_state: dict, session_state: SessionState) -> SessionState:
|
| 279 |
+
"""Setup a new SpoCK object."""
|
| 280 |
+
max_jobs = local_state['max_jobs']
|
| 281 |
+
spocks = session_state.spocks
|
| 282 |
+
|
| 283 |
+
# Check if the maximum number of jobs is reached and pop the last job if necessary
|
| 284 |
+
if len(spocks) >= max_jobs:
|
| 285 |
+
msg = f'max number of jobs ({max_jobs}); last job "{spocks[-1].setup.term}" is removed'
|
| 286 |
+
gr.Warning(msg)
|
| 287 |
+
logger.warning(msg)
|
| 288 |
+
spocks.pop(-1)
|
| 289 |
+
|
| 290 |
+
# Setup the running_spock and append it to the list of spocks
|
| 291 |
+
msg = f'started SpoCK for "{term}"'
|
| 292 |
+
gr.Info(msg)
|
| 293 |
+
logger.info(msg)
|
| 294 |
+
|
| 295 |
+
# Create new SpoCK object
|
| 296 |
+
setup = Setup(
|
| 297 |
+
id_=f'{term} {source} {model}',
|
| 298 |
+
term=term,
|
| 299 |
+
model=model,
|
| 300 |
+
source=source,
|
| 301 |
+
max_docs_src=max_docs_src,
|
| 302 |
+
log_level=local_state['log_level'],
|
| 303 |
+
n_scp_tasks=local_state['n_scp_tasks'],
|
| 304 |
+
n_ner_tasks=local_state['n_ner_tasks'],
|
| 305 |
+
)
|
| 306 |
+
spock = SpoCK(
|
| 307 |
+
id_=setup.id_,
|
| 308 |
+
status='running',
|
| 309 |
+
setup=setup,
|
| 310 |
+
started_at=setup.submission,
|
| 311 |
+
data=[]
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Set the running and focused SpoCK to be the new SpoCK object
|
| 315 |
+
index = 0
|
| 316 |
+
spocks.insert(index, spock)
|
| 317 |
+
session_state.set_running_spock(index=index)
|
| 318 |
+
session_state.set_active_spock(index=index)
|
| 319 |
+
return session_state
|
| 320 |
+
|
| 321 |
+
@staticmethod
|
| 322 |
+
async def _collector(session_state: SessionState) -> None:
|
| 323 |
+
"""Append the processed document(s) from the `session_state.ner_queue` to the running spock."""
|
| 324 |
+
running_spock = session_state.get_running_spock()
|
| 325 |
+
ner_queue = session_state.ner_queue
|
| 326 |
+
logger.debug(f'starting collector (term={running_spock.setup.term})')
|
| 327 |
+
|
| 328 |
+
while True:
|
| 329 |
+
item = await ner_queue.get() # type: QueueItem
|
| 330 |
+
# Check stopping condition (added by the `orchestrator` in `vianu.spock.__main__`)
|
| 331 |
+
if item is None:
|
| 332 |
+
ner_queue.task_done()
|
| 333 |
+
break
|
| 334 |
+
running_spock.data.append(item.doc)
|
| 335 |
+
ner_queue.task_done()
|
| 336 |
+
|
| 337 |
+
async def _setup_asyncio_framework(self, session_state: SessionState) -> SessionState:
|
| 338 |
+
""""Start the SpoCK processes by setting up the asyncio framework and starting the asyncio tasks.
|
| 339 |
+
|
| 340 |
+
Main components of asyncio framework are:
|
| 341 |
+
- ner_queue: queue for collecting results from named entity recognition tasks
|
| 342 |
+
- scp_tasks: scraping tasks (cf. `vianu.spock.src.scp`)
|
| 343 |
+
- ner_tasks: named entity recognition tasks (cf. `vianu.spock.src.ner`)
|
| 344 |
+
- orc_task: orchestrating the process
|
| 345 |
+
- col_task: collect and assemble the final results
|
| 346 |
+
"""
|
| 347 |
+
logger.info("setting up asyncio framework")
|
| 348 |
+
|
| 349 |
+
# Setup asyncio tasks as in `vianu.spock.__main__`
|
| 350 |
+
args_ = session_state.get_running_spock().setup.to_namespace()
|
| 351 |
+
ner_queue, scp_tasks, ner_tasks, orc_task = setup_asyncio_framework(args_=args_)
|
| 352 |
+
session_state.ner_queue = ner_queue
|
| 353 |
+
session_state.scp_tasks = scp_tasks
|
| 354 |
+
session_state.ner_tasks = ner_tasks
|
| 355 |
+
session_state.orc_task = orc_task
|
| 356 |
+
|
| 357 |
+
# Setup the app specific collection task
|
| 358 |
+
col_task = asyncio.create_task(self._collector(session_state=session_state))
|
| 359 |
+
session_state.col_task = col_task
|
| 360 |
+
|
| 361 |
+
return session_state
|
| 362 |
+
|
| 363 |
+
@staticmethod
|
| 364 |
+
async def _conclusion(session_state: SessionState) -> SessionState:
|
| 365 |
+
# Wait collector task to finish and join ner_queue
|
| 366 |
+
try:
|
| 367 |
+
await session_state.col_task
|
| 368 |
+
except asyncio.CancelledError:
|
| 369 |
+
logger.warning('collector task canceled')
|
| 370 |
+
return session_state # This stops the _conclusion step in the case the _canceling step was triggered
|
| 371 |
+
except Exception as e:
|
| 372 |
+
logger.error(f'collector task failed with error: {e}')
|
| 373 |
+
raise e
|
| 374 |
+
await session_state.ner_queue.join()
|
| 375 |
+
|
| 376 |
+
# Update the running_spock with the final data
|
| 377 |
+
running_spock = session_state.get_running_spock()
|
| 378 |
+
running_spock.status = 'completed'
|
| 379 |
+
running_spock.finished_at = datetime.now()
|
| 380 |
+
|
| 381 |
+
# Log the conclusion and update/empty the running_spock
|
| 382 |
+
gr.Info(f'job "{running_spock.setup.term}" finished')
|
| 383 |
+
logger.info(f'job "{running_spock.setup.term}" finished in {running_spock.runtime()}')
|
| 384 |
+
|
| 385 |
+
return session_state
|
| 386 |
+
|
| 387 |
+
@staticmethod
|
| 388 |
+
async def _canceling(session_state: SessionState) -> SessionState:
|
| 389 |
+
"""Cancel all running :class:`asyncio.Task`."""
|
| 390 |
+
running_spock = session_state.get_running_spock()
|
| 391 |
+
gr.Info(f'canceled SpoCK for "{running_spock.setup.term}"')
|
| 392 |
+
|
| 393 |
+
# Update the running_spock
|
| 394 |
+
running_spock.status = 'stopped'
|
| 395 |
+
running_spock.finished_at = datetime.now()
|
| 396 |
+
|
| 397 |
+
# Cancel scraping tasks
|
| 398 |
+
logger.warning("canceling scraping tasks")
|
| 399 |
+
for task in session_state.scp_tasks:
|
| 400 |
+
task.cancel()
|
| 401 |
+
await asyncio.gather(*session_state.scp_tasks, return_exceptions=True)
|
| 402 |
+
|
| 403 |
+
# Cancel named entity recognition tasks
|
| 404 |
+
logger.warning("canceling named entity recognition tasks")
|
| 405 |
+
for task in session_state.ner_tasks:
|
| 406 |
+
task.cancel()
|
| 407 |
+
await asyncio.gather(*session_state.ner_tasks, return_exceptions=True)
|
| 408 |
+
|
| 409 |
+
# Cancel orchestrator task
|
| 410 |
+
logger.warning("canceling orchestrator task")
|
| 411 |
+
session_state.orc_task.cancel()
|
| 412 |
+
await asyncio.gather(session_state.orc_task, return_exceptions=True) # we use return_exceptions=True to avoid raising exceptions due to the subtasks being allready canceled`
|
| 413 |
+
|
| 414 |
+
# Cancel collector task
|
| 415 |
+
logger.warning("canceling collector task")
|
| 416 |
+
session_state.col_task.cancel()
|
| 417 |
+
await asyncio.gather(session_state.col_task, return_exceptions=True) # see remark above
|
| 418 |
+
|
| 419 |
+
return session_state
|
| 420 |
+
|
| 421 |
+
@staticmethod
|
| 422 |
+
def _change_active_spock_number(session_state: SessionState, index: int) -> SessionState:
|
| 423 |
+
logger.debug(f'card clicked={index}')
|
| 424 |
+
session_state.set_active_spock(index=index)
|
| 425 |
+
return session_state
|
| 426 |
+
|
| 427 |
+
# --------------------------------------------------------------------------
|
| 428 |
+
# Events
|
| 429 |
+
# --------------------------------------------------------------------------
|
| 430 |
+
def _event_timer(self):
|
| 431 |
+
self._components['timer'].tick(
|
| 432 |
+
fn=self._feed_cards_to_ui,
|
| 433 |
+
inputs=[self._local_state, self._session_state],
|
| 434 |
+
outputs=self._components['main.cards'],
|
| 435 |
+
).then(
|
| 436 |
+
fn=self._feed_details_to_ui,
|
| 437 |
+
inputs=[self._session_state],
|
| 438 |
+
outputs=self._components['main.details'],
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
def _event_choose_llm(self):
|
| 442 |
+
"""Choose LLM model show the correspoding settings."""
|
| 443 |
+
self._components['settings.llm_radio'].change(
|
| 444 |
+
fn=self._show_llm_settings,
|
| 445 |
+
inputs=self._components['settings.llm_radio'],
|
| 446 |
+
outputs=[
|
| 447 |
+
self._components['settings.ollama_group'],
|
| 448 |
+
self._components['settings.openai_group'],
|
| 449 |
+
],
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
def _event_settings_ollama(self):
|
| 453 |
+
"""Callback of the ollama settings."""
|
| 454 |
+
self._components['settings.ollama_base_url'].submit(
|
| 455 |
+
fn=self._set_ollama_base_url,
|
| 456 |
+
inputs=self._components['settings.ollama_base_url'],
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
def _event_settings_openai(self):
|
| 460 |
+
"""Callback of the openai settings."""
|
| 461 |
+
self._components['settings.openai_api_key'].submit(
|
| 462 |
+
fn=self._set_openai_api_key,
|
| 463 |
+
inputs=self._components['settings.openai_api_key'],
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
def _event_start_spock(self) -> None:
|
| 467 |
+
search_term = self._components['main.search_term']
|
| 468 |
+
start_button = self._components['main.start_button']
|
| 469 |
+
timer = self._components['timer']
|
| 470 |
+
|
| 471 |
+
gr.on(
|
| 472 |
+
triggers=[search_term.submit, start_button.click],
|
| 473 |
+
fn=self._check_llm_settings,
|
| 474 |
+
inputs=self._components['settings.llm_radio'],
|
| 475 |
+
).success(
|
| 476 |
+
fn=self._toggle_button,
|
| 477 |
+
inputs=self._session_state,
|
| 478 |
+
outputs=[
|
| 479 |
+
self._session_state,
|
| 480 |
+
self._components['main.start_button'],
|
| 481 |
+
self._components['main.stop_button'],
|
| 482 |
+
self._components['main.cancel_button'],
|
| 483 |
+
]
|
| 484 |
+
).then(
|
| 485 |
+
fn=self._setup_spock,
|
| 486 |
+
inputs=[
|
| 487 |
+
search_term,
|
| 488 |
+
self._components['settings.llm_radio'],
|
| 489 |
+
self._components['settings.source'],
|
| 490 |
+
self._components['settings.max_docs_src'],
|
| 491 |
+
self._local_state,
|
| 492 |
+
self._session_state,
|
| 493 |
+
],
|
| 494 |
+
outputs=self._session_state
|
| 495 |
+
).then(
|
| 496 |
+
fn=self._setup_asyncio_framework,
|
| 497 |
+
inputs=self._session_state,
|
| 498 |
+
outputs=self._session_state,
|
| 499 |
+
).then(
|
| 500 |
+
fn=lambda: None, outputs=search_term # Empty the search term in the UI
|
| 501 |
+
).then(
|
| 502 |
+
fn=self._feed_cards_to_ui,
|
| 503 |
+
inputs=[self._local_state, self._session_state],
|
| 504 |
+
outputs=self._components['main.cards'],
|
| 505 |
+
).then(
|
| 506 |
+
fn=lambda: gr.update(active=True), outputs=timer
|
| 507 |
+
).then(
|
| 508 |
+
fn=self._conclusion,
|
| 509 |
+
inputs=self._session_state,
|
| 510 |
+
outputs=self._session_state,
|
| 511 |
+
).then(
|
| 512 |
+
fn=self._feed_cards_to_ui,
|
| 513 |
+
inputs=[self._local_state, self._session_state],
|
| 514 |
+
outputs=self._components['main.cards'],
|
| 515 |
+
).then(
|
| 516 |
+
fn=self._feed_details_to_ui, # called one more time in order to enforce update of the details (regardless of the state of the timer)
|
| 517 |
+
inputs=[self._session_state],
|
| 518 |
+
outputs=self._components['main.details'],
|
| 519 |
+
).then(
|
| 520 |
+
fn=lambda: gr.update(active=False), outputs=timer
|
| 521 |
+
).then(
|
| 522 |
+
fn=self._toggle_button,
|
| 523 |
+
inputs=self._session_state,
|
| 524 |
+
outputs=[
|
| 525 |
+
self._session_state,
|
| 526 |
+
self._components['main.start_button'],
|
| 527 |
+
self._components['main.stop_button'],
|
| 528 |
+
self._components['main.cancel_button'],
|
| 529 |
+
]
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
def _event_stop_spock(self):
|
| 533 |
+
# NOTE: when `stop_button.click` is triggered, the above pipeline (started by `search_term.submit` or
|
| 534 |
+
# `start_button.click`) is still running and awaiting the `_conclusion` step to finish. The `stop_button.click`
|
| 535 |
+
# event will cause the `_conclusion` step to terminate, after which the subsequent steps will still be executed;
|
| 536 |
+
# -> therefore, there is no need to add these steps here.
|
| 537 |
+
self._components['main.stop_button'].click(
|
| 538 |
+
fn=self._show_cancel_button,
|
| 539 |
+
outputs=[
|
| 540 |
+
self._components['main.start_button'],
|
| 541 |
+
self._components['main.stop_button'],
|
| 542 |
+
self._components['main.cancel_button'],
|
| 543 |
+
]
|
| 544 |
+
).then(
|
| 545 |
+
fn=self._canceling,
|
| 546 |
+
inputs=self._session_state,
|
| 547 |
+
outputs=self._session_state,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
def _event_card_click(self):
|
| 551 |
+
for index, crd in enumerate(self._components['main.cards']):
|
| 552 |
+
crd.click(
|
| 553 |
+
fn=self._change_active_spock_number,
|
| 554 |
+
inputs=[self._session_state, gr.Number(value=index, visible=False)],
|
| 555 |
+
outputs=self._session_state,
|
| 556 |
+
).then(
|
| 557 |
+
fn=self._feed_details_to_ui,
|
| 558 |
+
inputs=[self._session_state],
|
| 559 |
+
outputs=self._components['main.details'],
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
def register_events(self):
|
| 563 |
+
"""Register the events."""
|
| 564 |
+
# Setup timer for feed cards and details
|
| 565 |
+
self._event_timer()
|
| 566 |
+
|
| 567 |
+
# Settings events
|
| 568 |
+
self._event_choose_llm()
|
| 569 |
+
self._event_settings_ollama()
|
| 570 |
+
self._event_settings_openai()
|
| 571 |
+
|
| 572 |
+
# Start/Stop events
|
| 573 |
+
self._event_start_spock()
|
| 574 |
+
self._event_stop_spock()
|
| 575 |
+
|
| 576 |
+
# Card click events for showing details
|
| 577 |
+
self._event_card_click()
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
if __name__ == "__main__":
|
| 581 |
+
from vianu.spock.app import App
|
| 582 |
+
|
| 583 |
+
app = App()
|
| 584 |
+
demo = app.make()
|
| 585 |
+
demo.queue().launch(
|
| 586 |
+
favicon_path=app.favicon_path,
|
| 587 |
+
inbrowser=True,
|
| 588 |
+
allowed_paths=[
|
| 589 |
+
str(_ASSETS_PATH.resolve()),
|
| 590 |
+
],
|
| 591 |
+
server_port=GRADIO_SERVER_PORT,
|
| 592 |
+
)
|
vianu/spock/app/formatter.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from vianu.spock.src.base import Document, SpoCK
|
| 5 |
+
from vianu.spock.settings import DATE_FORMAT
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
JOBS_CONTAINER_CARD_TEMPLATE = """
|
| 11 |
+
<div class="card" onclick="cardClickHandler(this)">
|
| 12 |
+
<div class="title">{title} {status}</div>
|
| 13 |
+
<div class="info">Date: {date}</div>
|
| 14 |
+
<div class="info">Sources: {sources}</div>
|
| 15 |
+
<div class="info">#docs: {n_doc} | #adr: {n_adr}</div>
|
| 16 |
+
</div>
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
DETAILS_CONTAINER_TEMPLATE = """
|
| 20 |
+
<div id='details' class='details-container'>
|
| 21 |
+
<div class='items'>{items}</div>
|
| 22 |
+
</div>
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
DETAILS_CONTAINER_ITEM_TEMPLATE = """
|
| 26 |
+
<div class='item'>
|
| 27 |
+
<div class='top'>
|
| 28 |
+
<div class='favicon'><img src='{favicon}' alt='Favicon'></div>
|
| 29 |
+
<div class='title'><a href='{url}'>{title}</a></div>
|
| 30 |
+
</div>
|
| 31 |
+
<div class='bottom'>
|
| 32 |
+
{text}
|
| 33 |
+
</div>
|
| 34 |
+
</div>
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_details_html_items(data: List[Document]):
|
| 39 |
+
"""Get the HTML items for the details container. Each item contains the favicon, title, and the text with the
|
| 40 |
+
highlighted named entities.
|
| 41 |
+
"""
|
| 42 |
+
items = []
|
| 43 |
+
max_title_lenth = 120
|
| 44 |
+
for doc in data:
|
| 45 |
+
items.append(
|
| 46 |
+
DETAILS_CONTAINER_ITEM_TEMPLATE.format(
|
| 47 |
+
favicon=doc.source_favicon_url,
|
| 48 |
+
url=doc.url,
|
| 49 |
+
title=doc.title[:max_title_lenth]
|
| 50 |
+
+ ("..." if len(doc.title) > max_title_lenth else ""),
|
| 51 |
+
text=doc.get_html(),
|
| 52 |
+
details="details",
|
| 53 |
+
)
|
| 54 |
+
)
|
| 55 |
+
return "\n".join(items)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_details_html(data: List[Document]):
|
| 59 |
+
"""Get the stacked HTML items for each document."""
|
| 60 |
+
if len(data) == 0:
|
| 61 |
+
return "<div>no results available (yet)</div>"
|
| 62 |
+
sorted_data = sorted(data, key=lambda x: (len(x.adverse_reactions), len(x.medicinal_products)), reverse=True)
|
| 63 |
+
items = _get_details_html_items(data=sorted_data)
|
| 64 |
+
return DETAILS_CONTAINER_TEMPLATE.format(items=items)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _get_status_html(status: str) -> str:
|
| 68 |
+
"""Get the HTML for the status."""
|
| 69 |
+
if status == "running":
|
| 70 |
+
return f"<span class='running'>({status.upper()})</span>"
|
| 71 |
+
elif status == "completed":
|
| 72 |
+
return f"<span class='completed'>({status.upper()})</span>"
|
| 73 |
+
elif status == "stopped":
|
| 74 |
+
return f"<span class='stopped'>({status.upper()})</span>"
|
| 75 |
+
else:
|
| 76 |
+
logger.error(f"unknown status: {status.upper()})")
|
| 77 |
+
return '<span>(status unknown)</span>'
|
| 78 |
+
|
| 79 |
+
def get_job_card_html(card_nmbr: int, spock: SpoCK):
|
| 80 |
+
"""Get the HTML for the job card."""
|
| 81 |
+
job = spock.setup
|
| 82 |
+
data = spock.data
|
| 83 |
+
|
| 84 |
+
title = spock.setup.term
|
| 85 |
+
status = _get_status_html(spock.status)
|
| 86 |
+
sources = ", ".join(job.source)
|
| 87 |
+
date = job.submission.strftime(DATE_FORMAT)
|
| 88 |
+
n_doc = len(data)
|
| 89 |
+
n_adr = sum([len(d.adverse_reactions) for d in data])
|
| 90 |
+
return JOBS_CONTAINER_CARD_TEMPLATE.format(
|
| 91 |
+
nmbr=card_nmbr,
|
| 92 |
+
title=title,
|
| 93 |
+
status=status,
|
| 94 |
+
date=date,
|
| 95 |
+
sources=sources,
|
| 96 |
+
n_doc=n_doc,
|
| 97 |
+
n_adr=n_adr,
|
| 98 |
+
)
|
vianu/spock/assets/css/styles.css
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* no footer */
|
| 2 |
+
footer {
|
| 3 |
+
display: none !important;
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
/* customize scrollbar in gr.Dataframe */
|
| 7 |
+
::-webkit-scrollbar {
|
| 8 |
+
background: var(--background-fill-primary);
|
| 9 |
+
}
|
| 10 |
+
::-webkit-scrollbar-thumb {
|
| 11 |
+
background-color: var(--border-color-primary);
|
| 12 |
+
border: 4px solid transparent;
|
| 13 |
+
border-radius: 100px;
|
| 14 |
+
background-clip: content-box;
|
| 15 |
+
}
|
| 16 |
+
::-webkit-scrollbar-corner {
|
| 17 |
+
background: var(--background-fill-primary);
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
.top-row {
|
| 21 |
+
display: flex;
|
| 22 |
+
align-items: center;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
.top-row .image {
|
| 26 |
+
display: flex;
|
| 27 |
+
justify-content: center;
|
| 28 |
+
align-items: center;
|
| 29 |
+
background-color: transparent !important;
|
| 30 |
+
border: none !important;
|
| 31 |
+
box-shadow: none !important;
|
| 32 |
+
padding: 0 !important;
|
| 33 |
+
height: 150px;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
.top-row .title-desc {
|
| 37 |
+
justify-content: center;
|
| 38 |
+
display: block;
|
| 39 |
+
height: 100%;
|
| 40 |
+
background: var(--block-title-background-fill);
|
| 41 |
+
border-radius: var(--radius-lg);
|
| 42 |
+
font: var(--font);
|
| 43 |
+
text-align: center;
|
| 44 |
+
padding: var(--scale-0)
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
.top-row .title-desc .title {
|
| 48 |
+
color: var(--block-title-text-color);
|
| 49 |
+
font-size: var(--text-xxl);
|
| 50 |
+
font-weight: var(--weight-extrabold);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.top-row .title-desc .desc {
|
| 54 |
+
padding: 0px !important;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
.search-container {
|
| 58 |
+
display: flex;
|
| 59 |
+
align-items: flex-end;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
.search-container .button-not-running,
|
| 64 |
+
.search-container .button-running,
|
| 65 |
+
.search-container .canceling {
|
| 66 |
+
cursor: pointer;
|
| 67 |
+
display: flex;
|
| 68 |
+
justify-content: center;
|
| 69 |
+
align-items: center;
|
| 70 |
+
color: var(--color-grey-100);
|
| 71 |
+
border-radius: var(--radius-md);
|
| 72 |
+
font-size: var(--text-xxl);
|
| 73 |
+
font-weight: var(--weight-extrabold);
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
.search-container .button-not-running {
|
| 77 |
+
background: linear-gradient(to right, var(--color-pink-500), var(--button-primary-background-fill));
|
| 78 |
+
}
|
| 79 |
+
.search-container .button-running {
|
| 80 |
+
background: var(--color-red-600);
|
| 81 |
+
}
|
| 82 |
+
.search-container .canceling {
|
| 83 |
+
background: var(--color-grey-500);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
.jobs-container {
|
| 87 |
+
display: grid;
|
| 88 |
+
grid-template-columns: repeat(auto-fill, minmax(var(--size-80), 1fr));
|
| 89 |
+
}
|
| 90 |
+
.card {
|
| 91 |
+
justify-self: center;
|
| 92 |
+
cursor: pointer;
|
| 93 |
+
border: none var(--spacing-md) var(--block-title-border-color);
|
| 94 |
+
border-radius: var(--radius-lg);
|
| 95 |
+
width: var(--size-80);
|
| 96 |
+
padding: var(--scale-2);
|
| 97 |
+
font: var(--font);
|
| 98 |
+
background: var(--block-label-background-fill);
|
| 99 |
+
}
|
| 100 |
+
.card:active {
|
| 101 |
+
transform: scale(0.97);
|
| 102 |
+
}
|
| 103 |
+
.card .title {
|
| 104 |
+
font-weight: var(--weight-bold);
|
| 105 |
+
font-size: var(--text-lg);
|
| 106 |
+
color: var(--block-title-text-color);
|
| 107 |
+
margin-bottom: var(--scale-0);
|
| 108 |
+
}
|
| 109 |
+
.card .title .running{
|
| 110 |
+
color: var(--color-red-600);
|
| 111 |
+
font-size: var(--text-md);
|
| 112 |
+
}
|
| 113 |
+
.card .title .stopped{
|
| 114 |
+
color: var(--color-grey-500);
|
| 115 |
+
font-size: var(--text-md);
|
| 116 |
+
}
|
| 117 |
+
.card .title .completed{
|
| 118 |
+
color: var(--color-green-700);
|
| 119 |
+
font-size: var(--text-md);
|
| 120 |
+
}
|
| 121 |
+
.card .info {
|
| 122 |
+
color: var(--block-info-text-color);
|
| 123 |
+
margin-bottom: var(--block-label-margin);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
.details-container {
|
| 127 |
+
display: flex;
|
| 128 |
+
border-radius: 8px;
|
| 129 |
+
}
|
| 130 |
+
.details-container .title {
|
| 131 |
+
font-weight: var(--weight-bold);
|
| 132 |
+
font-size: var(--text-xl);
|
| 133 |
+
padding-left: var(--scale-0);
|
| 134 |
+
margin: var(--block-label-margin) 0 var(--block-label-margin) 0;
|
| 135 |
+
}
|
| 136 |
+
.details-container .info {
|
| 137 |
+
color: var(--block-info-text-color);
|
| 138 |
+
padding-left: var(--scale-0);
|
| 139 |
+
margin: var(--block-label-margin) 0 var(--block-label-margin) 0;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
.details-container .items {
|
| 143 |
+
margin-top: var(--scale-4);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
.details-container .items .item {
|
| 147 |
+
padding: var(--scale-0);
|
| 148 |
+
margin-bottom: var(--scale-0);
|
| 149 |
+
border: solid var(--block-label-border-width) var(--block-label-border-color);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
.details-container .items .item .top{
|
| 153 |
+
display: flex;
|
| 154 |
+
align-items: center;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
.details-container .items .item .top .favicon{
|
| 158 |
+
display: flex;
|
| 159 |
+
align-items: center;
|
| 160 |
+
margin-right: var(--scale-0);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
.details-container .items .item .top .favicon img{
|
| 164 |
+
height: 1.5em;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
.details-container .items .item .top .title{
|
| 168 |
+
display: flex;
|
| 169 |
+
align-items: center;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
.details-container .items .item .bottom{
|
| 173 |
+
display: flex;
|
| 174 |
+
align-items: center;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
.details-container .items .item .bottom .ner.mp,
|
| 178 |
+
.details-container .items .item .bottom .ner.adr {
|
| 179 |
+
padding: 0px 5px;
|
| 180 |
+
border-radius: 4px;
|
| 181 |
+
font-weight: var(--weight-bold);
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
.details-container .items .item .bottom .ner.mp {
|
| 185 |
+
color: var(--color-grey-100);
|
| 186 |
+
background-color: var(--color-purple-700);
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
.details-container .items .item .bottom .ner.adr {
|
| 190 |
+
background-color: var(--color-pink);
|
| 191 |
+
}
|
vianu/spock/assets/head/scripts.html
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<script>
|
| 2 |
+
function cardClickHandler(card) {
|
| 3 |
+
var cards = document.getElementsByClassName("card");
|
| 4 |
+
for (let c of cards) {
|
| 5 |
+
c.style.borderStyle = "none";
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
card.style.borderStyle = "solid";
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
</script>
|
vianu/spock/assets/images/favicon.png
ADDED
|
|
vianu/spock/assets/images/spock_logo.png
ADDED
|
vianu/spock/assets/images/spock_logo_circular.png
ADDED
|
vianu/spock/launch_demo_app.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from vianu import LOG_FMT
|
| 5 |
+
from vianu.spock.settings import LOG_LEVEL, GRADIO_SERVER_PORT
|
| 6 |
+
from vianu.spock.app import App
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(level=LOG_LEVEL.upper(), format=LOG_FMT)
|
| 9 |
+
os.environ["GRADIO_SERVER_PORT"] = str(GRADIO_SERVER_PORT)
|
| 10 |
+
|
| 11 |
+
if __name__ == "__main__":
|
| 12 |
+
app = App()
|
| 13 |
+
demo = app.make()
|
| 14 |
+
demo.queue().launch(
|
| 15 |
+
favicon_path=app.favicon_path,
|
| 16 |
+
inbrowser=True,
|
| 17 |
+
allowed_paths=app.allowed_paths,
|
| 18 |
+
)
|
vianu/spock/launch_demo_pipeline.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import Namespace
|
| 2 |
+
import asyncio
|
| 3 |
+
|
| 4 |
+
from vianu.spock.__main__ import main
|
| 5 |
+
from vianu.spock.settings import SCRAPING_SOURCES, MAX_DOCS_SRC
|
| 6 |
+
from vianu.spock.settings import N_SCP_TASKS, N_NER_TASKS
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
_ARGS = {
|
| 10 |
+
'term': 'ibuprofen',
|
| 11 |
+
'max_docs_src': MAX_DOCS_SRC,
|
| 12 |
+
'source': SCRAPING_SOURCES,
|
| 13 |
+
'model': 'llama',
|
| 14 |
+
'n_scp_tasks': N_SCP_TASKS,
|
| 15 |
+
'n_ner_tasks': N_NER_TASKS,
|
| 16 |
+
'log_level': 'DEBUG',
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if __name__ == '__main__':
|
| 21 |
+
asyncio.run(main(args_=Namespace(**_ARGS), save=False))
|
vianu/spock/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohttp==3.11.11
|
| 2 |
+
beautifulsoup4==4.12.3
|
| 3 |
+
dacite==1.8.1
|
| 4 |
+
gradio==5.10.0
|
| 5 |
+
numpy==2.2.1
|
| 6 |
+
pymupdf==1.25.1
|
| 7 |
+
python-dotenv==1.0.1
|
| 8 |
+
openai==1.59.5
|
| 9 |
+
defusedxml==0.7.1
|
vianu/spock/settings.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# General settings
|
| 2 |
+
LOG_LEVEL = 'DEBUG'
|
| 3 |
+
N_CHAR_DOC_ID = 12
|
| 4 |
+
FILE_PATH = "/tmp/spock/" # nosec
|
| 5 |
+
FILE_NAME = "spock"
|
| 6 |
+
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
| 7 |
+
|
| 8 |
+
# Gradio app settings
|
| 9 |
+
GRADIO_APP_NAME = "SpoCK"
|
| 10 |
+
GRADIO_SERVER_PORT=7868
|
| 11 |
+
GRADIO_MAX_JOBS = 5
|
| 12 |
+
GRADIO_UPDATE_INTERVAL = 2
|
| 13 |
+
|
| 14 |
+
# Scraping settings
|
| 15 |
+
SCRAPING_SOURCES = ['pubmed', 'ema', 'mhra', 'fda']
|
| 16 |
+
MAX_CHUNK_SIZE = 500
|
| 17 |
+
MAX_DOCS_SRC = 5
|
| 18 |
+
N_SCP_TASKS = 2
|
| 19 |
+
|
| 20 |
+
PUBMED_ESEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
|
| 21 |
+
PUBMED_EFETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
|
| 22 |
+
PUBMED_DB = 'pubmed'
|
| 23 |
+
PUBMED_BATCH_SIZE = 20
|
| 24 |
+
|
| 25 |
+
# NER settings
|
| 26 |
+
LARGE_LANGUAGE_MODELS = ['llama', 'openai']
|
| 27 |
+
MAX_TOKENS = 128.000
|
| 28 |
+
LLAMA_MODEL='llama3.2'
|
| 29 |
+
OPENAI_MODEL='gpt-4o'
|
| 30 |
+
N_NER_TASKS = 2
|
| 31 |
+
OLLAMA_BASE_URL_ENV_NAME = "OLLAMA_BASE_URL"
|
| 32 |
+
OPENAI_API_KEY_ENV_NAME = "OPENAI_API_KEY"
|
vianu/spock/src/__init__.py
ADDED
|
File without changes
|
vianu/spock/src/base.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from argparse import Namespace
|
| 3 |
+
from dataclasses import dataclass, asdict, field
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
from hashlib import sha256
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import dacite
|
| 12 |
+
import numpy as np
|
| 13 |
+
from typing import List, Self
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class Serializable:
|
| 20 |
+
"""Abstract base class for all dataclasses that can be serialized to a dictionary."""
|
| 21 |
+
def to_dict(self) -> dict:
|
| 22 |
+
"""Converts the object to a dictionary."""
|
| 23 |
+
return asdict(self)
|
| 24 |
+
|
| 25 |
+
@classmethod
|
| 26 |
+
def from_dict(cls, dict_: dict) -> Self:
|
| 27 |
+
"""Creates an object from a dictionary."""
|
| 28 |
+
return dacite.from_dict(
|
| 29 |
+
data_class=cls,
|
| 30 |
+
data=dict_,
|
| 31 |
+
config=dacite.Config(type_hooks={datetime: datetime.fromisoformat})
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class Identicator(ABC):
|
| 37 |
+
"""Abstract base class for entities with customized id.
|
| 38 |
+
|
| 39 |
+
Notes
|
| 40 |
+
The identifier :param:`Identicator.id_` is hashed and enriched with `_id_prefix` if this is
|
| 41 |
+
not present. This means as long as the `id_` begins with `_id_prefix` nothing is done.
|
| 42 |
+
|
| 43 |
+
This behavior aims to allow:
|
| 44 |
+
SubIdenticator(id_='This is the string that identifies the entity')
|
| 45 |
+
|
| 46 |
+
and with _id_prefix='sub' it produces an id_ of the form:
|
| 47 |
+
id_ = 'sub_5d41402abc4b2a76b9719d911017c592'
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
id_: str
|
| 51 |
+
|
| 52 |
+
def __eq__(self, other: object) -> bool:
|
| 53 |
+
if not isinstance(other, Identicator):
|
| 54 |
+
return NotImplemented
|
| 55 |
+
return self.id_ == other.id_
|
| 56 |
+
|
| 57 |
+
def __post_init__(self):
|
| 58 |
+
if not self.id_.startswith(self._id_prefix):
|
| 59 |
+
self.id_ = self._id_prefix + self._hash_id_str()
|
| 60 |
+
|
| 61 |
+
def _hash_id_str(self):
|
| 62 |
+
return sha256(self.id_.encode()).hexdigest()
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
@abstractmethod
|
| 66 |
+
def _id_prefix(self):
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass(eq=False)
class NamedEntity(Identicator, Serializable):
    """A single named entity extracted from a document.

    Equality is inherited from :class:`Identicator` (``eq=False`` keeps the
    dataclass machinery from overriding it).
    """

    # surface form of the entity as it appears in the text
    text: str = field(default_factory=str)
    # entity type label (e.g. 'MP' or 'ADR')
    class_: str = field(default_factory=str)
    # [start, stop] character offsets within the host text, if resolved
    location: List[int] | None = None

    @property
    def _id_prefix(self):
        """Prefix marking hashed named-entity ids."""
        return 'ent_'
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass(eq=False)
class Document(Identicator, Serializable):
    """Class containing any document related information.

    Holds the raw document text plus metadata, the named entities found in it,
    and a cached HTML rendering (``_html``) keyed by a hash of the entity ids
    (``_html_hash``) so the rendering is recomputed only when entities change.
    """

    # mandatory document fields
    text: str
    source: str

    # additional document fields
    title: str | None = None
    url: str | None = None
    source_url: str | None = None
    source_favicon_url: str | None = None
    language: str | None = None
    publication_date: datetime | None = None

    # named entities
    medicinal_products: List[NamedEntity] = field(default_factory=list)
    adverse_reactions: List[NamedEntity] = field(default_factory=list)

    # protected fields (cached HTML rendering and its fingerprint)
    _html: str | None = None
    _html_hash: str | None = None

    @property
    def _id_prefix(self):
        # prefix marking hashed document ids
        return 'doc_'

    def remove_named_entity_by_id(self, id_: str) -> None:
        """Removes a named entity from the document by a given `doc.id_`."""
        self.medicinal_products = [ne for ne in self.medicinal_products if ne.id_ != id_]
        self.adverse_reactions = [ne for ne in self.adverse_reactions if ne.id_ != id_]

    def _get_html_hash(self) -> str:
        """Creates a sha256 hash from the named entities' ids. If the sets of named entities have been modified, this
        function will return a different hash.
        """
        ne_ids = [ne.id_ for ne in self.medicinal_products + self.adverse_reactions]
        html_hash_str = ' '.join(ne_ids)
        return sha256(html_hash_str.encode()).hexdigest()


    def _get_html(self) -> str:
        """Creates the HTML representation of the document with highlighted named entities."""
        text = f"<div>{self.text}</div>"

        # Highlight medicinal products according to the css class 'mp'
        # NOTE(review): str.replace substitutes *all* occurrences of ne.text,
        # so repeated or overlapping entity strings may be wrapped more than
        # once — confirm this is acceptable for the rendered output.
        mp_template = "<span class='ner mp'>{text} | {class_}</span>"
        for ne in self.medicinal_products:
            text = text.replace(
                ne.text, mp_template.format(text=ne.text, class_=ne.class_)
            )

        # Highlight adverse drug reactions according to the css class 'adr'
        adr_template = "<span class='ner adr'>{text} | {class_}</span>"
        for ne in self.adverse_reactions:
            text = text.replace(
                ne.text, adr_template.format(text=ne.text, class_=ne.class_)
            )

        return text


    def get_html(self) -> str:
        """Returns the HTML representation of the document with highlighted named entities. This function checks if
        the set of named entities has been modified and updates the HTML representation if necessary."""
        html_hash = self._get_html_hash()
        if self._html is None or html_hash != self._html_hash:
            self._html = self._get_html()
            self._html_hash = html_hash
        return self._html
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@dataclass(eq=False)
class Setup(Identicator, Serializable):
    """Pipeline setup; mirrors the CLI arguments (see ``src/cli.py``)."""

    # generic options
    log_level: str
    max_docs_src: int

    # scraping options
    term: str
    source: List[str]
    n_scp_tasks: int

    # NER options
    model: str
    n_ner_tasks: int

    # optional fields
    submission: datetime | None = None
    file_path: str | None = None
    file_name: str | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        # default the submission timestamp to "now" when not provided
        if self.submission is None:
            self.submission = datetime.now()

    @property
    def _id_prefix(self) -> str:
        """Prefix marking hashed setup ids."""
        return 'stp_'

    def to_namespace(self) -> Namespace:
        """Converts the :class:`Setup` object to a :class:`argparse.Namespace` object."""
        return Namespace(**asdict(self))

    @classmethod
    def from_namespace(cls, args_: Namespace) -> Self:
        """Creates a :class:`Setup` object from a :class:`argparse.Namespace` object."""
        kwargs = vars(args_)
        return cls(id_=str(kwargs), **kwargs)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@dataclass
class QueueItem:
    """A single unit of work passed through the :class:`asyncio.Queue` pipeline."""

    # correlation id of the work item
    id_: str
    # the document to be processed
    doc: Document
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@dataclass
class SpoCK(Identicator, Serializable):
    """Main class for the SpoCK pipeline mainly containing the job definition and the resulting data."""

    # Generic fields
    status: str | None = None
    started_at: datetime | None = None
    finished_at: datetime | None = None

    # Pipeline fields
    setup: Setup | None = None
    data: List[Document] = field(default_factory=list)

    @property
    def _id_prefix(self) -> str:
        """Prefix marking hashed SpoCK job ids."""
        return 'spk_'

    def runtime(self) -> timedelta | None:
        """Elapsed time of the job, or None when it has not been started.

        While the job is still running (no ``finished_at``), the runtime is
        measured against the current wall clock.
        """
        if self.started_at is None:
            return None
        end = self.finished_at if self.finished_at is not None else datetime.now()
        return end - self.started_at
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
SpoCKList = List[SpoCK]
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class JSONEncoder(json.JSONEncoder):
    """Custom JSON encoder handling :class:`datetime`, :class:`Document` and numpy scalars."""

    def default(self, obj):
        # datetimes are serialized as ISO-8601 strings
        if isinstance(obj, datetime):
            return obj.isoformat()
        # documents are serialized via their dataclass dictionary form
        if isinstance(obj, Document):
            return asdict(obj)
        # numpy scalars are not JSON serializable out of the box
        if isinstance(obj, np.float32):
            return str(obj)
        if isinstance(obj, np.int64):
            return int(obj)
        # Let the base class default method raise the TypeError
        return super().default(obj)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class FileHandler:
    """Reads from and writes data to a JSON file under a given file path."""

    _suffix = '.json'

    def __init__(self, file_path: Path | str) -> None:
        """Args:
            file_path: directory under which the JSON files are read and
                written; created (including parents) if it does not exist.
        """
        self._file_path = Path(file_path)
        # exist_ok avoids a race between the exists() check and the creation
        os.makedirs(self._file_path, exist_ok=True)

    def read(self, filename: str) -> SpoCK:
        """Reads the data from a JSON file and casts it into a :class:`SpoCK` object.

        Args:
            filename: file name (with or without the ``.json`` suffix),
                relative to the handler's base path.
        """
        # resolve relative to the base path; with_suffix is applied exactly once
        filename = (self._file_path / filename).with_suffix(self._suffix)

        # log the actual file being read (was a '(unknown)' placeholder)
        logger.info(f'reading data from file {filename}')
        with open(filename, 'r', encoding="utf-8") as dfile:
            dict_ = json.load(dfile)

        return SpoCK.from_dict(dict_=dict_)

    def write(self, file_name: str, spock: SpoCK, add_dt: bool = True) -> None:
        """Writes the data to a JSON file.

        If `add_dt=True`, the filename is `{file_name}_%Y%m%d%H%M%S.json`.
        """
        if add_dt:
            file_name = f'{file_name}_{datetime.now().strftime("%Y%m%d%H%M%S")}'
        file_name = (self._file_path / file_name).with_suffix(self._suffix)

        logger.info(f'writing data to file {file_name}')
        with open(file_name, 'w', encoding="utf-8") as dfile:
            json.dump(spock.to_dict(), dfile, cls=JSONEncoder)
|
vianu/spock/src/cli.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI for SpoCK
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
from typing import Sequence
|
| 6 |
+
|
| 7 |
+
from vianu.spock.settings import LOG_LEVEL, FILE_NAME, FILE_PATH, MAX_DOCS_SRC
|
| 8 |
+
from vianu.spock.settings import SCRAPING_SOURCES, N_SCP_TASKS
|
| 9 |
+
from vianu.spock.settings import N_NER_TASKS, LARGE_LANGUAGE_MODELS
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_args(args_: Sequence) -> argparse.Namespace:
    """Parse the SpoCK command line arguments.

    Args:
        args_: raw argument sequence (e.g. ``sys.argv[1:]``)

    Returns:
        An :class:`argparse.Namespace` with the generic, scraping and NER
        options; defaults come from :mod:`vianu.spock.settings`.
    """
    parser = argparse.ArgumentParser(description="SpoCK", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Add generic options
    gen_gp = parser.add_argument_group('generic')
    gen_gp.add_argument("--log-level", metavar='', type=str, default=LOG_LEVEL, help='log level')
    gen_gp.add_argument("--file-path", metavar='', type=str, default=FILE_PATH, help='path for storing results')
    gen_gp.add_argument("--file-name", metavar='', type=str, default=FILE_NAME, help='filename for storing results')
    gen_gp.add_argument("--max-docs-src", metavar='', type=int, default=MAX_DOCS_SRC, help='maximum number of documents per source')

    # Add scraping group
    # NOTE: --source uses action='append', so it may be given multiple times
    scp_gp = parser.add_argument_group('scraping')
    scp_gp.add_argument('--source', '-s', type=str, action='append', choices=SCRAPING_SOURCES, help='data sources for scraping')
    scp_gp.add_argument('--term', '-t', metavar='', type=str, help='search term')
    scp_gp.add_argument('--n-scp-tasks', metavar='', type=int, default=N_SCP_TASKS, help='number of async scraping tasks')

    # Add NER group
    ner_gp = parser.add_argument_group('ner')
    ner_gp.add_argument('--model', '-m', type=str, choices=LARGE_LANGUAGE_MODELS, default='llama', help='NER model')
    ner_gp.add_argument('--n-ner-tasks', metavar='', type=int, default=N_NER_TASKS, help='number of async ner tasks')

    return parser.parse_args(args_)
|
vianu/spock/src/ner.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import aiohttp
|
| 3 |
+
from argparse import Namespace
|
| 4 |
+
import asyncio
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
from openai import AsyncOpenAI
|
| 12 |
+
|
| 13 |
+
from vianu.spock.src.base import NamedEntity, QueueItem # noqa: F401
|
| 14 |
+
from vianu.spock.settings import N_CHAR_DOC_ID, LLAMA_MODEL, OPENAI_MODEL
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
load_dotenv()
|
| 18 |
+
|
| 19 |
+
NAMED_ENTITY_PROMPT = """
|
| 20 |
+
You are an expert in Natural Language Processing. Your task is to identify named entities (NER) in a given text.
|
| 21 |
+
You will focus on the following entities: adverse drug reaction (entity type: ADR), medicinal product (entity type: MP).
|
| 22 |
+
Once you identified all named entities of the above types, you return them as a Python list of tuples of the form (text, type).
|
| 23 |
+
It is important to only provide the Python list as your output, without any additional explanations or text.
|
| 24 |
+
In addition, make sure that the named entity texts are exact copies of the original text segment
|
| 25 |
+
|
| 26 |
+
Example 1:
|
| 27 |
+
Input:
|
| 28 |
+
"The most commonly reported side effects of dafalgan include headache, nausea, and fatigue."
|
| 29 |
+
|
| 30 |
+
Output:
|
| 31 |
+
[("dafalgan", "MP"), ("headache", "ADR"), ("nausea", "ADR"), ("fatigue", "ADR")]
|
| 32 |
+
|
| 33 |
+
Example 2:
|
| 34 |
+
Input:
|
| 35 |
+
"Patients taking acetaminophen or naproxen have reported experiencing skin rash, dry mouth, and difficulty breathing after taking this medication. In rare cases, seizures have also been observed."
|
| 36 |
+
|
| 37 |
+
Output:
|
| 38 |
+
[("acetaminophen", "MP"), ("naproxen", "MP"), ("skin rash", "ADR"), ("dry mouth", "ADR"), ("difficulty breathing", "ADR"), ("seizures", "ADR")]
|
| 39 |
+
|
| 40 |
+
Example 3:
|
| 41 |
+
Input:
|
| 42 |
+
"There are reported side effects as dizziness, stomach upset, and in some instances, temporary memory loss. These are mainly observed after taking Amitiza (lubiprostone) or Trulance (plecanatide)."
|
| 43 |
+
|
| 44 |
+
Output:
|
| 45 |
+
[("dizziness", "ADR"), ("stomach upset", "ADR"), ("temporary memory loss", "ADR"), ("Amitiza (lubiprostone)", "MP"), ("Trulance (plecanatide)", "MP")]
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class NER(ABC):
    """Abstract base class for named-entity recognition over queued documents.

    Subclasses implement :meth:`_get_ner_model_answer`, which queries a
    concrete LLM backend and returns its raw text answer; the shared
    :meth:`apply` loop parses that answer into :class:`NamedEntity` objects
    and attaches them to the document.
    """

    # matches tuples like ("some text", "MP") / ("some text", "ADR") in the model answer
    _named_entity_pattern = re.compile(r'\("([^"]+)",\s?"(MP|ADR)"\)')

    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def _get_loc_of_subtext(text: str, subtext: str) -> List[int] | None:
        """Get the [start, stop] location of the first occurrence of `subtext` in `text` (None if absent)."""
        pos = text.find(subtext)
        if pos == -1:
            return None
        return [pos, pos + len(subtext)]

    def _add_loc_for_named_entities(self, text: str, named_entities: List[NamedEntity]) -> None:
        """Resolve each entity's location in `text` (case-insensitively) and normalize its surface form.

        Entities that cannot be located keep ``location=None`` and are reported
        via a warning; callers typically filter those out afterwards.
        """
        txt_low = text.lower()
        for ne in named_entities:
            ne_txt_low = ne.text.lower()
            loc = self._get_loc_of_subtext(text=txt_low, subtext=ne_txt_low)

            if loc is not None:
                ne.location = loc
                # use the original casing as found in the document
                ne.text = text[loc[0]:loc[1]]
            else:
                self.logger.warning(f'could not find location for named entity "{ne.text}" of class "{ne.class_}"')

    @staticmethod
    def _get_messages(text: str) -> List[dict]:
        """Build the chat messages (system prompt + user text) sent to the model."""
        text = f'Process the following input text: "{text}"'
        return [
            {
                "role": "system",
                "content": NAMED_ENTITY_PROMPT,
            },
            {
                "role": "user",
                "content": text,
            },
        ]

    @abstractmethod
    async def _get_ner_model_answer(self, text: str) -> str:
        """Return the raw model answer for `text`; implemented by backend subclasses."""
        # Fixed: the original message wrongly referred to OpenAINER in this base class.
        raise NotImplementedError('NER._get_ner_model_answer must be implemented by a subclass')

    async def apply(self, queue_in: asyncio.Queue, queue_out: asyncio.Queue) -> None:
        """Apply NER to a text received from input queue and put the results in an output queue.

        The loop terminates when a ``None`` sentinel is read from `queue_in`.
        On model errors the item is dropped, i.e. it is not forwarded to
        `queue_out`.
        """
        while True:
            # Get text from input queue
            item = await queue_in.get()  # type: QueueItem

            # Check stopping condition (None sentinel)
            if item is None:
                queue_in.task_done()
                break

            # Get the model response with named entities
            id_ = item.id_
            doc = item.doc
            self.logger.debug(f'starting ner for item.id_={id_} (doc.id_={doc.id_[:N_CHAR_DOC_ID]})')
            try:
                text = doc.text
                content = await self._get_ner_model_answer(text=text)
            except Exception as e:
                self.logger.error(f'error during ner for item.id_={item.id_} (doc.id_={doc.id_[:N_CHAR_DOC_ID]}): {e}')
                queue_in.task_done()
                continue

            # Parse the model answer and remove duplicates (order is not preserved)
            ne_list = re.findall(self._named_entity_pattern, content)
            ne_list = list(set(ne_list))

            # Create list of NamedEntity objects; the id_ string combines the
            # source text, the entity text and its class so ids are stable
            named_entities = []
            for ne in ne_list:
                try:
                    txt, cls_ = ne
                    named_entities.append(NamedEntity(id_=f'{text} {txt} {cls_}', text=txt, class_=cls_))
                except Exception as e:
                    self.logger.error(f'error during creation of `NamedEntity` using {ne}: {e}')

            # Add locations to named entities and remove those without location
            self._add_loc_for_named_entities(text=text, named_entities=named_entities)
            named_entities = [ne for ne in named_entities if ne.location is not None]

            # Assign named entities to the document
            ne_mp = [ne for ne in named_entities if ne.class_ == 'MP']
            ne_adr = [ne for ne in named_entities if ne.class_ == 'ADR']
            self.logger.debug(f'found #mp={len(ne_mp)} and #adr={len(ne_adr)} for item.id_={id_} (doc.id_={doc.id_[:N_CHAR_DOC_ID]})')
            doc.medicinal_products = ne_mp
            doc.adverse_reactions = ne_adr

            # Put the document in the output queue
            await queue_out.put(item)
            queue_in.task_done()
            self.logger.info(f'finished NER task for item.id_={id_} (doc.id_={doc.id_[:N_CHAR_DOC_ID]})')
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class OllamaNER(NER):
    """NER backend talking to an Ollama server via its HTTP chat API."""

    def __init__(self, base_url: str, model: str):
        """Args:
            base_url: base URL of the Ollama server (without `/api/...`)
            model: name of the chat model to use
        """
        super().__init__()
        self._base_url = base_url
        self._model = model

    def _get_http_data(self, text: str, stream: bool = False) -> dict:
        """Assemble the JSON payload for the `/api/chat/` endpoint."""
        return {
            "model": self._model,
            "messages": self._get_messages(text=text),
            "stream": stream,
        }

    async def _get_ner_model_answer(self, text: str) -> str:
        """POST the chat request and return the model's answer text."""
        payload = self._get_http_data(text=text)
        endpoint = f'{self._base_url}/api/chat/'
        async with aiohttp.ClientSession() as session:
            async with session.post(endpoint, json=payload) as response:
                response.raise_for_status()
                body = await response.json()
                return body['message']['content']
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class OpenAINER(NER):
    """NER backend using the asynchronous OpenAI chat completion API."""

    def __init__(self, api_key: str, model: str):
        """Args:
            api_key: OpenAI API key
            model: chat model identifier
        """
        super().__init__()
        self._model = model
        self._client = AsyncOpenAI(api_key=api_key)

    async def _get_ner_model_answer(self, text: str) -> str:
        """Send the chat completion request and return the first choice's content."""
        completion = await self._client.chat.completions.create(
            messages=self._get_messages(text=text),
            model=self._model,
        )
        return completion.choices[0].message.content
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def create_tasks(args_: Namespace, queue_in: asyncio.Queue, queue_out: asyncio.Queue) -> List[asyncio.Task]:
    """Create asyncio NER tasks.

    Args:
        args_: pipeline arguments; `args_.model` selects the backend
            ('llama' or 'openai') and `args_.n_ner_tasks` the task count.
        queue_in: queue delivering items to process
        queue_out: queue receiving processed items

    Raises:
        EnvironmentError: if the selected backend's environment variable is unset
        ValueError: if `args_.model` is unknown
    """
    n_ner_tasks = args_.n_ner_tasks
    model = args_.model

    if model == 'llama':
        base_url = os.environ.get('OLLAMA_BASE_URL')
        if base_url is None:
            # Fixed: the message previously named OLLAMA_ENDPOINT while the
            # code reads OLLAMA_BASE_URL.
            raise EnvironmentError("The ollama endpoint must be set by the OLLAMA_BASE_URL environment variable")
        ner = OllamaNER(base_url=base_url, model=LLAMA_MODEL)

    elif model == 'openai':
        api_key = os.environ.get('OPENAI_API_KEY')
        if api_key is None:
            raise EnvironmentError("The api_key for the OpenAI client must be set by the OPENAI_API_KEY environment variable")
        ner = OpenAINER(api_key=api_key, model=OPENAI_MODEL)

    else:
        raise ValueError(f"unknown ner model '{args_.model}'")

    logger.info(f'setting up {n_ner_tasks} NER task(s)')
    tasks = [asyncio.create_task(ner.apply(queue_in=queue_in, queue_out=queue_out)) for _ in range(n_ner_tasks)]
    return tasks
|
vianu/spock/src/scraping.py
ADDED
|
@@ -0,0 +1,922 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module for scraping data from different sources.
|
| 2 |
+
|
| 3 |
+
The module contains three main classes:
|
| 4 |
+
- :class:`Scraper`: Abstract base class for scraping data from different sources
|
| 5 |
+
- :class:`PubmedScraper`: Class for scraping data from the PubMed database
|
| 6 |
+
- :class:`EMAScraper`: Class for scraping data from the European Medicines Agency
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
from argparse import Namespace
|
| 11 |
+
import asyncio
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from io import BytesIO
|
| 15 |
+
import logging
|
| 16 |
+
import re
|
| 17 |
+
from typing import List
|
| 18 |
+
import xml.etree.ElementTree as ET # nosec
|
| 19 |
+
|
| 20 |
+
import aiohttp
|
| 21 |
+
from bs4 import BeautifulSoup
|
| 22 |
+
from bs4.element import Tag
|
| 23 |
+
import defusedxml.ElementTree as DET
|
| 24 |
+
import numpy as np
|
| 25 |
+
import pymupdf
|
| 26 |
+
|
| 27 |
+
from vianu.spock.src.base import Document, QueueItem # noqa: F401
|
| 28 |
+
from vianu.spock.settings import MAX_CHUNK_SIZE, SCRAPING_SOURCES
|
| 29 |
+
from vianu.spock.settings import PUBMED_ESEARCH_URL, PUBMED_DB, PUBMED_EFETCH_URL, PUBMED_BATCH_SIZE
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Scraper(ABC):
    """Abstract base class for scraping documents from a single source."""

    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)

    @abstractmethod
    async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
        """Main function for scraping data from a source.

        Args:
            - args_: the arguments for the spock pipeline
            - queue_out: the output queue for the scraped data
        """
        pass

    @staticmethod
    def split_text_into_chunks(text: str, chunk_size: int = MAX_CHUNK_SIZE, separator: str = ' ') -> List[str]:
        """Split a text into roughly equal chunks of at most `chunk_size` words.

        The text is split on `separator` and chunk boundaries are distributed
        evenly over the word list so that no chunk exceeds `chunk_size` words.
        """
        words = text.split(separator)
        N = len(words)
        s = min(chunk_size, N)
        # Ceiling division: with floor division (N // s) the last chunks could
        # exceed `chunk_size` (e.g. 5 words with chunk_size=2 gave a chunk of 3).
        n = -(-N // s)
        bnd = [round(i) for i in np.linspace(0, 1, n + 1) * N]

        chunks = [separator.join(words[start:stop]) for start, stop in zip(bnd[:-1], bnd[1:])]
        return chunks

    @staticmethod
    async def _aiohttp_get_html(url: str, headers=None) -> str:
        """Get the content of a given URL by an aiohttp GET request."""
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.get(url=url) as response:
                response.raise_for_status()
                text = await response.text()
                return text
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class PubmedEntrezHistoryParams:
    """Entrez history-server parameters for batched Pubmed retrieval of large result sets.

    An example can be found here:
    https://www.ncbi.nlm.nih.gov/books/n/helpeutils/chapter3/#chapter3.Application_3_Retrieving_large
    """

    # WebEnv token identifying the stored search on the history server
    web: str
    # query key within that WebEnv
    key: str
    # total number of hits reported by esearch
    count: int
|
| 82 |
+
|
| 83 |
+
class PubmedScraper(Scraper):
    """Class for scraping data from the PubMed database.

    The scraper uses the Pubmed API to search for relevant documents. From the list of results it creates a list of
    :class:`Document` objects by the following main steps:
    - Extract all PubmedArticle elements from the search results (other types are ignored)
    - Extract the AbstractText from the PubmedArticle (if there is no abstract, the document is ignored)
    """

    _source = 'pubmed'
    _source_url = 'https://pubmed.ncbi.nlm.nih.gov/'
    _source_favicon_url = 'https://www.ncbi.nlm.nih.gov/favicon.ico'

    _robots_txt_url = 'https://www.ncbi.nlm.nih.gov/robots.txt'

    @staticmethod
    def _get_entrez_history_params(text: str) -> PubmedEntrezHistoryParams:
        """Retrieve the entrez history parameters for optimized search when requesting large numbers of documents.

        An example can be found here:
        https://www.ncbi.nlm.nih.gov/books/n/helpeutils/chapter3/#chapter3.Application_3_Retrieving_large

        NOTE(review): the re.search calls assume the ESearch response always contains WebEnv,
        QueryKey and Count tags; a malformed response would raise AttributeError here.
        """
        web = re.search(r'<WebEnv>(\S+)</WebEnv>', text).group(1)
        key = re.search(r'<QueryKey>(\d+)</QueryKey>', text).group(1)
        count = int(re.search(r'<Count>(\d+)</Count>', text).group(1))
        return PubmedEntrezHistoryParams(web=web, key=key, count=count)

    async def _pubmed_esearch(self, term: str) -> str:
        """Search the Pubmed database with a given term and POST the results to entrez history server."""
        url = f'{PUBMED_ESEARCH_URL}?db={PUBMED_DB}&term={term}&usehistory=y'
        self.logger.debug(f'search pubmed database with url={url}')
        esearch = await self._aiohttp_get_html(url=url)
        return esearch

    async def _pubmed_efetch(self, params: PubmedEntrezHistoryParams, max_docs_src: int) -> List[str]:
        """Retrieve the relevant documents from the entrez history server.

        Args:
        - params: entrez history server handle (WebEnv/QueryKey/Count) from a previous ESearch
        - max_docs_src: upper bound on the number of documents to retrieve

        Returns the raw XML payloads, one string per fetched batch.
        """
        # Reduce the number of documents to be retrieved for efficiency
        N = min(max_docs_src, int(params.count))
        if N < params.count:
            self.logger.warning(f'from the total number of documents={params.count} only {N} will be retrieved')

        # BUGFIX: with an empty search result (count == 0) the original code computed
        # batch_size == 0 and called range(0, 0, 0), which raises
        # "ValueError: range() arg 3 must not be zero". Return early instead.
        if N == 0:
            return []

        # Iterate over the batches of documents (with fixed batch size)
        batch_size = min(params.count, PUBMED_BATCH_SIZE)
        n_batches = (N + batch_size - 1) // batch_size  # ceil division (the old "N // batch_size + 1" over-counted when N was a multiple of batch_size)
        self.logger.debug(f'fetch #docs={N} in {n_batches} batch(es) of size <= {batch_size}')
        batches = []
        for retstart in range(0, N, batch_size):

            # Prepare URL for retrieving next batch of documents but stop if the maximum number is reached
            retmax = min(max_docs_src - len(batches)*batch_size, batch_size)
            url = f'{PUBMED_EFETCH_URL}?db={PUBMED_DB}&WebEnv={params.web}&query_key={params.key}&retstart={retstart}&retmax={retmax}'
            self.logger.debug(f'fetch documents with url={url}')

            # Fetch the documents
            efetch = await self._aiohttp_get_html(url=url)
            batches.append(efetch)
        return batches

    def _extract_medline_citation(self, element: ET.Element) -> ET.Element | None:
        """Extract the MedlineCitation element from a PubmedArticle element, or None if absent."""
        # Find and extract the MedlineCitation element
        citation = element.find('MedlineCitation')
        if citation is None:
            self.logger.warning('no "MedlineCitation" element found')
            return None
        return citation

    @staticmethod
    def _extract_pmid(element: ET.Element) -> str | None:
        """Extract the PMID from a MedlineCitation element."""
        pmid = element.find('PMID')
        return pmid.text if pmid is not None else None

    @staticmethod
    def _extract_article(element: ET.Element) -> ET.Element | None:
        """Extract the Article element from a MedlineCitation element (None if absent)."""
        article = element.find('Article')
        return article

    @staticmethod
    def _extract_title(article: ET.Element) -> str | None:
        """Extract the title from an Article element."""
        title = article.find('ArticleTitle')
        return title.text if title is not None else None

    @staticmethod
    def _extract_abstract(article: ET.Element) -> str | None:
        """Extract the abstract from an Article element.

        Multiple AbstractText sections are joined by blank lines; returns None when the
        article carries no Abstract element at all.
        """
        separator = '\n\n'
        abstract = article.find('Abstract')
        if abstract is not None:
            abstract = separator.join([a.text for a in abstract.findall('AbstractText') if a.text is not None])
        return abstract

    @staticmethod
    def _extract_language(article: ET.Element) -> str | None:
        """Extract the language from an Article element."""
        language = article.find('Language')
        return language.text if language is not None else None

    @staticmethod
    def _extract_date(article: ET.Element) -> datetime | None:
        """Extract the publication date from an Article element.

        NOTE(review): assumes an existing ArticleDate always has Year/Month/Day children;
        a partial date would raise AttributeError here.
        """
        date = article.find('ArticleDate')
        if date is None:
            return None
        year = int(date.find('Year').text)
        month = int(date.find('Month').text)
        day = int(date.find('Day').text)
        return datetime(year=year, month=month, day=day)

    @staticmethod
    def _extract_publication_types(article: ET.Element) -> List[str]:
        """Extract the publication types from an Article element."""
        return [t.text for t in article.find('PublicationTypeList').findall('PublicationType')]

    def _parse_pubmed_articles(self, batches: List[str]) -> List[Document]:
        """Parse batches of ET.Elements into a single list of Document objects."""
        data = []
        for ib, text in enumerate(batches):
            pubmed_articles = DET.fromstring(text).findall('PubmedArticle')
            self.logger.debug(f'found #articles={len(pubmed_articles)} in batch {ib}')
            for ie, element in enumerate(pubmed_articles):
                self.logger.debug(f'parsing PubmedArticle {ie} of batch {ib}')
                # Extract MedlineCitation and its PMID from PubmedArticle
                citation = self._extract_medline_citation(element=element)
                if citation is None:
                    self.logger.debug(f'no citation found in PubmedArticle {ie} of batch {ib}')
                    continue
                pmid = self._extract_pmid(element=citation)

                # Extract the Article element from the PubmedArticle
                article = self._extract_article(element=citation)
                if article is None:
                    self.logger.debug(f'no article found in PubmedArticle {ie} of batch {ib}')
                    continue

                # Extract the relevant information from the Article element.
                # (Renamed from 'text' to 'abstract' — the original rebound the outer
                # loop variable holding the raw batch payload.)
                title = self._extract_title(article=article)
                abstract = self._extract_abstract(article=article)
                if abstract is None:
                    self.logger.debug(f'no abstract found in PubmedArticle {ie} of batch {ib}')
                    continue
                language = self._extract_language(article=article)
                publication_date = self._extract_date(article=article)

                # Split long texts into chunks
                texts = self.split_text_into_chunks(text=abstract)

                # Create the Document object(s)
                for txt in texts:
                    document = Document(
                        id_=f'{self._source_url} {title} {txt} {language} {publication_date}',
                        text=txt,
                        source=self._source,
                        title=title,
                        url=f'{self._source_url}{pmid}/',
                        source_url=self._source_url,
                        source_favicon_url=self._source_favicon_url,
                        language=language,
                        publication_date=publication_date,
                    )
                    data.append(document)
        self.logger.debug(f'parsed #docs={len(data)} from #batches={len(batches)}')
        return data

    async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
        """Query and retrieve all PubmedArticle Documents for the given search term.

        The retrieval is using two main functionalities of the Pubmed API:
        - ESearch: Identify the relevant documents and store them in the entrez history server
        - EFetch: Retrieve the relevant documents from the entrez history server

        Args:
        - args_: the arguments for the spock pipeline
        - queue_out: the output queue for the scraped data
        """
        term = args_.term
        max_docs_src = args_.max_docs_src
        self.logger.debug(f'starting scraping the source={self._source} with term={term}')

        # Search for relevant documents with a given term
        esearch = await self._pubmed_esearch(term=term)

        # Retrieve relevant documents in batches
        params = self._get_entrez_history_params(esearch)
        batches = await self._pubmed_efetch(params=params, max_docs_src=max_docs_src)

        # Parse documents from batches (chunking may have produced more than requested)
        documents = self._parse_pubmed_articles(batches=batches)
        documents = documents[:max_docs_src]

        # Add documents to the queue
        for i, doc in enumerate(documents):
            id_ = f'{self._source}_{i}'
            item = QueueItem(id_=id_, doc=doc)
            await queue_out.put(item)

        self.logger.info(f'retrieved #docs={len(documents)} in source={self._source} for term={term}')
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@dataclass
class SearchResults:
    """Container for the search results retrieved from the different databases."""
    count: int | None    # total number of results reported by the source (None if it could not be read)
    n_pages: int | None  # number of result pages (None when unknown / not paginated)
    items: List[Tag]     # raw result tags, one per hit
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class EMAScraper(Scraper):
    """Class for scraping data from the European Medicines Agency.

    The scraper uses the same API as the web interface of the EMA to search for relevant documents. From the list of
    results it creates a list of :class:`Document` objects by the following main steps:
    - Search the EMA database (filter for PDF documents only)
    - Extract the text from the PDF documents
    - Use regex to find texts where adverse drug reactions (or similar) are mentioned
    - Return the most recent :class:`Document` objects
    """

    _source = 'ema'
    _source_url = 'https://www.ema.europa.eu'
    _source_favicon_url = 'https://www.ema.europa.eu/themes/custom/ema_theme/favicon.ico'

    _pdf_search_template = (
        "https://www.ema.europa.eu/en/search?search_api_fulltext={term}"
        "&f%5B0%5D=ema_search_custom_entity_bundle%3Adocument"  # This part is added to only retrieve PDF documents
        "&f%5B1%5D=ema_search_entity_is_document%3ADocument"
    )
    _headers = {
        'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
        'Host': 'www.ema.europa.eu',
    }

    _robots_txt_url = 'https://www.ema.europa.eu/robots.txt'

    def _extract_search_results_count(self, soup: BeautifulSoup) -> int | None:
        """Extract the number of search results."""
        span = soup.find("span", class_="source-summary-count")
        if span is None:
            self.logger.warning('no search results count found')
            return None
        # The count is rendered as "(N)"
        return int(span.text.strip('()'))

    def _extract_number_of_pages(self, soup: BeautifulSoup) -> int | None:
        """Extract the number of pages from the search results (pager is 0-indexed, hence +1)."""
        nav = soup.find('nav', class_='pager')
        if nav is not None:
            a = nav.find('a', {'class': 'page-link', 'aria-label': 'Last'})
            if a and a.has_attr('href'):
                href = a['href']
                match = re.search(r'&page=(\d+)', href)
                if match:
                    return int(match.group(1)) + 1
        self.logger.warning('no pager found')
        return None

    @staticmethod
    def _extract_search_item_divs(soup: BeautifulSoup) -> List[Tag]:
        """Extract the list of div elements containing the different search results."""
        parent = soup.find('div', class_=['row', 'row-cols-1'])
        return parent.find_all('div', class_='col')

    async def _ema_document_search(self, term: str, max_docs_src: int) -> SearchResults:
        """Search the EMA database for PDF documents with a given term.

        Args:
        - term: the search term
        - max_docs_src: stop paging once this many result items have been collected
        """
        # Get initial search results
        base_url = self._pdf_search_template.format(term=term)
        self.logger.debug(f'search ema database with url={base_url}')
        content = await self._aiohttp_get_html(url=base_url, headers=self._headers)
        soup = BeautifulSoup(content, 'html.parser')

        # Get the number of search results and number of pages
        count = self._extract_search_results_count(soup=soup)
        n_pages = self._extract_number_of_pages(soup=soup)

        # Extract the divs containing the search results
        items = []
        if count is not None and count > 0:
            # Extract items from page=0
            items_from_page = self._extract_search_item_divs(soup=soup)
            items.extend(items_from_page)

            # Extract items from page=1, 2, ...
            if n_pages is not None and n_pages > 1:
                for i in range(1, n_pages):
                    # BUGFIX: build each page URL from the unmodified base URL. The original
                    # appended '&page={i}' to the previously modified URL, so page parameters
                    # accumulated (…&page=1&page=2&page=3…).
                    page_url = f'{base_url}&page={i}'
                    content = await self._aiohttp_get_html(url=page_url, headers=self._headers)
                    soup = BeautifulSoup(content, 'html.parser')

                    items_from_page = self._extract_search_item_divs(soup=soup)
                    items.extend(items_from_page)

                    if len(items) >= max_docs_src:
                        self.logger.debug(f'found #items={len(items)} in #pages={i+1}')
                        break

        # Check for extraction mismatch (expected when paging stopped early at max_docs_src)
        if len(items) != count:
            self.logger.warning(f'mismatch #items={len(items)} and the total count={count}')

        self.logger.debug(f'extracted #items={len(items)} in #pages={n_pages}')
        return SearchResults(count=count, n_pages=n_pages, items=items)

    @staticmethod
    def _extract_title(tag: Tag) -> str | None:
        """Extract the title of the document."""
        title = tag.find('p', class_='file-title')
        return title.text if title is not None else None

    def _extract_url(self, tag: Tag) -> str | None:
        """Extract the link's href to the relevant PDF document (None for non-PDF links)."""
        link = tag.find('a', href=True)
        if link is None:
            self.logger.warning('no link found')
            return None

        href = link['href']
        url = f'{self._source_url}{href}' if href.startswith('/') else href
        if not url.endswith('.pdf'):
            self.logger.warning(f'url={url} does not point to a PDF document')
            return None
        return url

    async def _extract_text(self, url: str) -> str:
        """Download a PDF document and extract the text of all its pages."""
        async with aiohttp.ClientSession(headers=self._headers) as session:
            async with session.get(url) as response:
                response.raise_for_status()
                content = await response.read()  # Read the entire content

        # Create a BytesIO object from the content
        pdf_stream = BytesIO(content)

        # BUGFIX: open the PDF inside a context manager so the document is closed even
        # when text extraction raises (the original called doc.close() unconditionally
        # after extraction and leaked the handle on error).
        with pymupdf.open(stream=pdf_stream, filetype="pdf") as doc:
            # Extract text from all pages
            text = '\n'.join([page.get_text() for page in doc])
        return text

    @staticmethod
    def _extract_language(tag: Tag) -> str | None:
        """Extract the language of the document (the part in parentheses, lower-cased)."""
        lang_tag = tag.find('p', class_='language-meta')
        if lang_tag is None:
            return None
        text = lang_tag.text
        start = text.find('(')
        stop = text.find(')')
        if start != -1 and stop != -1:
            return text[start+1:stop].lower()
        return None

    @staticmethod
    def _extract_date(tag: Tag) -> datetime | None:
        """Extract the publication date of the document."""
        time_tag = tag.find('time')
        if time_tag and time_tag.has_attr('datetime'):
            return datetime.fromisoformat(time_tag['datetime'])
        return None

    async def _parse_items(self, items: List[Tag]) -> List[Document]:
        """From a list of divs containing the search results, extract the relevant information and parse it into a list
        of :class:`Document` objects.
        """
        data = []
        for i, tag in enumerate(items):
            url = self._extract_url(tag=tag)
            if url is None:
                self.logger.debug(f'no url found for item {i}')
                continue

            # Extract the relevant information from the document
            self.logger.debug(f'parsing document with url={url}')
            title = self._extract_title(tag=tag)
            text = self._extract_text(url=url)
            language = self._extract_language(tag=tag)
            publication_date = self._extract_date(tag=tag)

            # Split long texts into chunks
            texts = self.split_text_into_chunks(text=await text if False else text)

            # Create the Document object(s)
            for chunk in texts:
                document = Document(
                    id_=f'{self._source_url} {title} {chunk} {language} {publication_date}',
                    text=chunk,
                    source=self._source,
                    title=title,
                    url=url,
                    source_url=self._source_url,
                    source_favicon_url=self._source_favicon_url,
                    language=language,
                    publication_date=publication_date,
                )
                data.append(document)
        self.logger.debug(f'created #docs={len(data)}')
        return data

    async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
        """Query and retrieve all PRAC documents for the given search term.

        Args:
        - args_: the arguments for the spock pipeline
        - queue_out: the output queue for the scraped data
        """
        term = args_.term
        max_docs_src = args_.max_docs_src
        self.logger.debug(f'starting scraping the source={self._source} with term={term}')

        # Search for relevant documents with a given term
        search_results = await self._ema_document_search(term=term, max_docs_src=max_docs_src)
        n_items = len(search_results.items)
        if n_items > max_docs_src:
            self.logger.warning(f'from #items={n_items} only max_docs_src={max_docs_src} will be parsed')
        items = search_results.items[:max_docs_src]

        # Parse the documents
        data = await self._parse_items(items=items)
        n_data = len(data)
        if n_data > max_docs_src:
            self.logger.warning(f'the #items={n_items} were chunked into #documents={n_data} from where only max_docs_src={max_docs_src} will be added to the queue')
        data = data[:max_docs_src]

        # Add documents to the queue
        for i, doc in enumerate(data):
            id_ = f'{self._source}_{i}'
            item = QueueItem(id_=id_, doc=doc)
            await queue_out.put(item)

        self.logger.info(f'retrieved #docs={len(data)} in source={self._source} for term={term}')
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class MHRAScraper(Scraper):
    """Class for scraping data from the Medicines and Healthcare products Regulatory Agency.

    The scraper uses the MHRA's **Drug Safety Update** search API for retrieving relevant documents. From the list of
    results it creates a list of :class:`Document` objects.
    """

    _source = 'mhra'
    # BUGFIX: was 'https://www.gov.uk/durg-safety-update' ('durg' typo) — the search
    # template below confirms the correct path segment is 'drug-safety-update'.
    _source_url = 'https://www.gov.uk/drug-safety-update'
    _source_favicon_url = 'https://www.gov.uk/favicon.ico'

    _search_template = 'https://www.gov.uk/drug-safety-update?keywords={term}'
    _source_base_url = 'https://www.gov.uk'
    _language = 'en'  # all Drug Safety Update content is English

    def _extract_search_results_count(self, parent: Tag) -> int | None:
        """Extract the number of search results from the results header."""
        div = parent.find('div', class_='result-info__header')
        h2 = div.find('h2') if div else None

        if h2 is None:
            self.logger.warning('no search results count found')
            return None
        text = h2.get_text(strip=True)
        # BUGFIX: guard the regex — the original called .group() on a possible None
        # when the header contained no digits.
        match = re.search(r'\d+', text)
        if match is None:
            self.logger.warning('no search results count found')
            return None
        return int(match.group())

    @staticmethod
    def _extract_search_item_divs(parent: Tag) -> List[Tag]:
        """Extract the divs containing the search results."""
        return parent.find_all('li', class_='gem-c-document-list__item')

    async def _mhra_document_search(self, term: str) -> SearchResults:
        """Search the MHRA database for documents with a given term."""

        # Get search results and extract divs containing the search results
        url = self._search_template.format(term=term)
        self.logger.debug(f"search mhra's drug safety update database with url={url}")
        content = await self._aiohttp_get_html(url=url)
        soup = BeautifulSoup(content, 'html.parser')
        parent = soup.find('div', class_=['govuk-grid-column-two-thirds', 'js-live-search-results-block', 'filtered-results'])

        # Robustness: without the results container nothing can be extracted (the
        # original passed None on to _extract_search_results_count and crashed).
        if parent is None:
            self.logger.warning('no search results container found')
            return SearchResults(count=None, n_pages=None, items=[])

        # Extract the number of search results
        count = self._extract_search_results_count(parent=parent)

        # Extract the divs containing the search results
        items = []
        if count is not None and count > 0:
            items = self._extract_search_item_divs(parent=parent)  # For a given search term, the site shows all the results without pagination.

        # Check for extraction mismatch
        if len(items) != count:
            self.logger.warning(f'mismatch #items={len(items)} and the total count={count}')

        self.logger.debug(f'found #items={len(items)}')
        # BUGFIX: SearchResults has a required n_pages field; the original call omitted
        # it and raised TypeError. The MHRA results are not paginated, hence None.
        return SearchResults(count=count, n_pages=None, items=items)

    def _extract_url(self, link: Tag) -> str | None:
        """Extract the url to the document (relative hrefs are resolved against gov.uk)."""
        href = link['href']
        return f'{self._source_base_url}{href}' if href.startswith('/') else href

    async def _extract_text(self, url: str) -> str:
        """Extract the text from the document.

        NOTE(review): assumes the page always has a <main> element — verify against the
        gov.uk markup before hardening.
        """
        content = await self._aiohttp_get_html(url=url)
        soup = BeautifulSoup(content, 'html.parser')
        main = soup.find('main')
        text = main.get_text()

        # Clean the spaces and newlines
        text = re.sub(r'(\n\s*){3,}', '\n\n', text)
        text = re.sub(r'^ +', '', text, flags=re.MULTILINE)
        return text

    @staticmethod
    def _extract_date(tag: Tag) -> datetime | None:
        """Extract the publication date of the document."""
        time_tag = tag.find('time')
        if time_tag and time_tag.has_attr('datetime'):
            return datetime.fromisoformat(time_tag['datetime'])
        return None

    async def _parse_items(self, items: List[Tag]) -> List[Document]:
        """From a list of divs containing the search results (items), extract the relevant information and parse it into a list
        of :class:`Document` objects.
        """
        data = []
        for i, tag in enumerate(items):
            link = tag.find('a', href=True)
            if link is None:
                self.logger.warning(f'no link found for item {i}')
                continue
            url = self._extract_url(link=link)

            # Extract the relevant information from the document
            self.logger.debug(f'parsing item with url={url}')
            title = link.get_text(strip=True)
            text = await self._extract_text(url=url)
            publication_date = self._extract_date(tag=tag)

            # Split long texts into chunks
            texts = self.split_text_into_chunks(text=text)

            # Create the Document object(s)
            for chunk in texts:
                document = Document(
                    id_=f'{self._source_url} {title} {chunk} {publication_date}',
                    text=chunk,
                    source=self._source,
                    title=title,
                    url=url,
                    source_url=self._source_url,
                    source_favicon_url=self._source_favicon_url,
                    language=self._language,
                    publication_date=publication_date,
                )
                data.append(document)
        self.logger.debug(f'created #docs={len(data)}')
        return data

    async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
        """Query and retrieve all drug safety updates for the given search term.

        Args:
        - args_: the arguments for the spock pipeline
        - queue_out: the output queue for the scraped data
        """
        term = args_.term
        max_docs_src = args_.max_docs_src
        self.logger.debug(f'starting scraping the source={self._source} with term={term}')

        # Search for relevant documents with a given term
        search_results = await self._mhra_document_search(term=term)
        n_items = len(search_results.items)
        if n_items > max_docs_src:
            self.logger.warning(f'from #items={n_items} only max_docs_src={max_docs_src} will be parsed')
        items = search_results.items[:max_docs_src]

        # Parse the documents
        data = await self._parse_items(items=items)
        n_data = len(data)
        if n_data > max_docs_src:
            self.logger.warning(f'the #items={n_items} were chunked into #documents={n_data} from where only max_docs_src={max_docs_src} will be added to the queue')
        data = data[:max_docs_src]

        # Add documents to the queue
        for i, doc in enumerate(data):
            id_ = f'{self._source}_{i}'
            item = QueueItem(id_=id_, doc=doc)
            await queue_out.put(item)

        self.logger.info(f'retrieved #docs={len(data)} in source={self._source} for term={term}')
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
class FDAScraper(Scraper):
|
| 675 |
+
"""Class for scraping data from the Food and Drug Administration.
|
| 676 |
+
|
| 677 |
+
The scraper uses the the same API as the web interface fo the FDA to search for relevant documents. By default the search applies the following filter:
|
| 678 |
+
- sorting by highest relevance
|
| 679 |
+
- filter for results from the Center of Drug Evaluation and Research
|
| 680 |
+
- filter for English language
|
| 681 |
+
- filter for Drugs
|
| 682 |
+
"""
|
| 683 |
+
|
| 684 |
+
_source = 'fda'
|
| 685 |
+
_source_url = 'https://www.fda.gov'
|
| 686 |
+
_source_favicon_url = 'https://www.fda.gov/favicon.ico'
|
| 687 |
+
|
| 688 |
+
_search_template = (
|
| 689 |
+
'https://www.fda.gov/search?s={term}'
|
| 690 |
+
'&items_per_page=10'
|
| 691 |
+
'&sort_bef_combine=rel_DESC' # Sort by relevance
|
| 692 |
+
'&f%5B0%5D=center%3A815' # Filter for the Center for Drug Evaluation and Research
|
| 693 |
+
'&f%5B1%5D=language%3A1404' # Filter for English language
|
| 694 |
+
'&f%5B2%5D=prod%3A2312' # Filter for the Drugs section
|
| 695 |
+
)
|
| 696 |
+
_language = 'en'
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def _extract_search_results_count(self, soup: BeautifulSoup) -> int | None:
    """Extract the number of search results from the search info section."""
    parent = soup.find('div', class_='lcds-search-filters__info')
    if parent is not None:
        div = parent.find('div', class_='view-header')
        # BUGFIX: guard against a missing view-header (the original dereferenced
        # div.text unconditionally and raised AttributeError).
        if div is not None:
            # BUGFIX: the original pattern 'entr[y|ies]' is a character class matching a
            # single character from {y,|,i,e,s}, not the intended alternation entry/entries.
            match = re.search(r'of (\d+) entr(?:y|ies)', div.text)
            if match:
                return int(match.group(1))
    self.logger.warning('no search info found')
    return None
|
| 709 |
+
|
| 710 |
+
def _extract_number_of_pages(self, soup: BeautifulSoup) -> int | None:
    """Extract the number of pages from the search results (pager is 0-indexed, hence +1)."""
    pager = soup.find('nav', class_=['pager-nav', 'text-center'])
    if pager is not None:
        last_item = pager.find('li', class_=['pager__item', 'pager__item--last'])
        anchor = last_item.find('a')
        if anchor and anchor.has_attr('href'):
            page_match = re.search(r'&page=(\d+)', anchor['href'])
            if page_match:
                return int(page_match.group(1)) + 1
    self.logger.warning('no pager found')
    return None
|
| 723 |
+
|
| 724 |
+
@staticmethod
def _extract_search_item_divs(soup: BeautifulSoup) -> List[Tag]:
    """Extract the divs containing the search results."""
    container = soup.find('div', class_='view-content')
    # Only the immediate children are result entries
    return container.find_all('div', recursive=False)
|
| 729 |
+
|
| 730 |
+
async def _fda_document_search(self, term: str, max_docs_src: int) -> SearchResults:
    """Search the FDA database for documents with a given term.

    Args:
    - term: the search term
    - max_docs_src: stop paging once this many result items have been collected
    """
    # Get search results
    base_url = self._search_template.format(term=term)
    self.logger.debug(f'search fda database with url={base_url}')
    content = await self._aiohttp_get_html(url=base_url)
    soup = BeautifulSoup(content, 'html.parser')

    # Get the number of search results and the number of pages
    count = self._extract_search_results_count(soup=soup)
    n_pages = self._extract_number_of_pages(soup=soup)

    # Extract the divs containing the search results
    items = []
    if count is not None and count > 0:
        # Extract items from page=0
        items_from_page = self._extract_search_item_divs(soup=soup)
        items.extend(items_from_page)

        # Extract items from page=1, 2, ...
        if n_pages is not None and n_pages > 1:
            for i in range(1, n_pages):
                # BUGFIX: build each page URL from the unmodified base URL. The original
                # appended '&page={i}' to the previously modified URL, so page parameters
                # accumulated (…&page=1&page=2&page=3…).
                page_url = f'{base_url}&page={i}'
                content = await self._aiohttp_get_html(url=page_url)
                soup = BeautifulSoup(content, 'html.parser')

                items_from_page = self._extract_search_item_divs(soup=soup)
                items.extend(items_from_page)

                if len(items) >= max_docs_src:
                    self.logger.debug(f'found #items={len(items)} in #pages={i+1}')
                    break

    # Check for extraction mismatch (expected when paging stopped early at max_docs_src)
    if len(items) != count:
        self.logger.warning(f'mismatch #items={len(items)} and the total count={count}')

    self.logger.debug(f'extracted #items={len(items)} in #pages={n_pages}')
    return SearchResults(count=count, n_pages=n_pages, items=items)
|
| 770 |
+
|
| 771 |
+
@staticmethod
def _extract_title(main: Tag) -> str | None:
    """Extract the title of the document."""
    header = main.find('header', class_=['row', 'content-header'])
    if header is None:
        return None
    heading = header.find('h1', class_=['content-title', 'text-center'])
    return heading.text if heading is not None else None
|
| 780 |
+
|
| 781 |
+
@staticmethod
def _extract_text(main: Tag) -> str | None:
    """Extract the text from the document's main content column."""
    body = main.find('div', attrs={'class': 'col-md-8 col-md-push-2', 'role': 'main'})
    if body is None:
        return None
    return body.get_text()
|
| 786 |
+
|
| 787 |
+
@staticmethod
def _extract_date(main: Tag) -> datetime | None:
    """Extract the publication date of the document.

    Prefers the <time> element inside the description list's 'cell-2_2' entry; falls
    back to the first <time> element in <main>.
    """
    dl = main.find('dl', class_='lcds-description-list--grid')
    if dl is not None:
        dd = dl.find('dd', class_='cell-2_2')
        # BUGFIX: guard against a description list without the expected cell — the
        # original called dd.find('time') and raised AttributeError when dd was None.
        time_tag = dd.find('time') if dd is not None else main.find('time')
    else:
        time_tag = main.find('time')
    if time_tag and time_tag.has_attr('datetime'):
        return datetime.fromisoformat(time_tag['datetime'])
    return None
|
| 799 |
+
|
| 800 |
+
async def _parse_item_page(self, url: str) -> List[Document]:
|
| 801 |
+
content = await self._aiohttp_get_html(url=url)
|
| 802 |
+
soup = BeautifulSoup(content, 'html.parser')
|
| 803 |
+
main = soup.find('main')
|
| 804 |
+
|
| 805 |
+
# Extract the relevant information from the document
|
| 806 |
+
title = self._extract_title(main=main)
|
| 807 |
+
text = self._extract_text(main=main)
|
| 808 |
+
publication_date = self._extract_date(main=main)
|
| 809 |
+
|
| 810 |
+
# Split long texts into chunks
|
| 811 |
+
texts = self.split_text_into_chunks(text=text)
|
| 812 |
+
|
| 813 |
+
# Create the Document object(s)
|
| 814 |
+
data = []
|
| 815 |
+
for text in texts:
|
| 816 |
+
document = Document(
|
| 817 |
+
id_=f'{self._source_url} {title} {text} {publication_date}',
|
| 818 |
+
text=text,
|
| 819 |
+
source=self._source,
|
| 820 |
+
title=title,
|
| 821 |
+
url=url,
|
| 822 |
+
source_url=self._source_url,
|
| 823 |
+
source_favicon_url=self._source_favicon_url,
|
| 824 |
+
language=self._language,
|
| 825 |
+
publication_date=publication_date,
|
| 826 |
+
)
|
| 827 |
+
data.append(document)
|
| 828 |
+
return data
|
| 829 |
+
|
| 830 |
+
async def _parse_items(self, items: List[Tag]) -> List[Document]:
|
| 831 |
+
"""From a list of divs containing the search results, extract the relevant information and parse it into a list
|
| 832 |
+
of :class:`Document` objects.
|
| 833 |
+
"""
|
| 834 |
+
data = []
|
| 835 |
+
for i, tag in enumerate(items):
|
| 836 |
+
link = tag.find('a', href=True)
|
| 837 |
+
if link is None:
|
| 838 |
+
self.logger.warning(f'no link found for item {i}')
|
| 839 |
+
continue
|
| 840 |
+
url = self._source_url + link['href']
|
| 841 |
+
|
| 842 |
+
# Create the Document object
|
| 843 |
+
self.logger.debug(f'parsing item with url={url}')
|
| 844 |
+
page_data = await self._parse_item_page(url=url)
|
| 845 |
+
data.extend(page_data)
|
| 846 |
+
self.logger.debug(f'created #docs={len(data)}')
|
| 847 |
+
return data
|
| 848 |
+
|
| 849 |
+
async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
|
| 850 |
+
term = args_.term
|
| 851 |
+
max_docs_src = args_.max_docs_src
|
| 852 |
+
self.logger.debug(f'starting scraping the source={self._source} with term={term}')
|
| 853 |
+
|
| 854 |
+
# Search for relevant documents with a given term
|
| 855 |
+
search_results = await self._fda_document_search(term=term, max_docs_src=max_docs_src)
|
| 856 |
+
n_items = len(search_results.items)
|
| 857 |
+
if n_items > max_docs_src:
|
| 858 |
+
self.logger.warning(f'from #items={n_items} only max_docs_src={max_docs_src} will be parsed')
|
| 859 |
+
items = search_results.items[:max_docs_src]
|
| 860 |
+
|
| 861 |
+
# Parse the documents
|
| 862 |
+
data = await self._parse_items(items=items)
|
| 863 |
+
n_data = len(data)
|
| 864 |
+
if n_data > max_docs_src:
|
| 865 |
+
self.logger.warning(f'the #items={n_items} were chunked into #documents={n_data} from where only max_docs_src={max_docs_src} will be added to the queue')
|
| 866 |
+
data = data[:max_docs_src]
|
| 867 |
+
|
| 868 |
+
# Add documents to the queue
|
| 869 |
+
for i, doc in enumerate(data):
|
| 870 |
+
id_ = f'{self._source}_{i}'
|
| 871 |
+
item = QueueItem(id_=id_, doc=doc)
|
| 872 |
+
await queue_out.put(item)
|
| 873 |
+
|
| 874 |
+
self.logger.info(f'retrieved #docs={len(data)} in source={self._source} for term={term}')
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
# Registry of available scraper classes; must stay aligned (same order and
# same length) with the SCRAPING_SOURCES setting.
_SCRAPERS = [PubmedScraper, EMAScraper, MHRAScraper, FDAScraper]

# Validate BEFORE building the mapping: zip() silently truncates unequal
# sequences, so checking afterwards would let an incomplete source->scraper
# table be constructed first.
if len(_SCRAPERS) != len(SCRAPING_SOURCES):
    raise ValueError("number of scrapers and sources do not match")

_SOURCE_TO_SCRAPER = dict(zip(SCRAPING_SOURCES, _SCRAPERS))
|
| 881 |
+
|
| 882 |
+
async def _scraping(args_: Namespace, queue_in: asyncio.Queue, queue_out: asyncio.Queue) -> None:
    """Pop a source (str) from the input queue, perform the scraping task with the given term, and put the results in
    the output queue until the input queue is empty.

    Args:
        - args_: the arguments for the spock pipeline
        - queue_in: the input queue containing the sources to scrape
        - queue_out: the output queue for the scraped data
    """

    while True:
        # Get source from input queue
        source = await queue_in.get()

        # A ``None`` sentinel signals that no more sources will arrive
        if source is None:
            queue_in.task_done()
            break

        # Look up the scraper class registered for this source
        scraper = _SOURCE_TO_SCRAPER.get(source)  # type: type[Scraper]
        if scraper is None:
            # Skip unknown sources but keep the worker alive so the remaining
            # queue entries are still processed (previously this `break`-ed,
            # terminating the worker and stranding valid sources in the queue).
            logger.error(f'unknown source={source}')
            queue_in.task_done()
            continue

        # Best-effort scraping: log failures and move on to the next source
        try:
            await scraper().apply(args_=args_, queue_out=queue_out)
        except Exception as e:
            logger.error(f'error during scraping for source={source} and term={args_.term}: {e}')
        finally:
            queue_in.task_done()
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
def create_tasks(args_: Namespace, queue_in: asyncio.Queue, queue_out: asyncio.Queue) -> List[asyncio.Task]:
    """Create the asyncio scraping tasks."""
    n_tasks = args_.n_scp_tasks
    logger.info(f'setting up {n_tasks} scraping task(s) for source(s)={args_.source}')
    tasks = []
    for _ in range(n_tasks):
        worker = _scraping(args_=args_, queue_in=queue_in, queue_out=queue_out)
        tasks.append(asyncio.create_task(worker))
    return tasks
|