jaegglic commited on
Commit
c5e4363
·
1 Parent(s): 1be5cb1

Initial commit

Browse files
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Spock
3
- emoji: 📈
4
- colorFrom: purple
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
8
  short_description: Spotting clinical knowledge
 
1
  ---
2
+ title: SpoCK
3
+ emoji: 🖖
4
+ colorFrom: yellow
5
+ colorTo: green
6
  sdk: docker
7
  pinned: false
8
  short_description: Spotting clinical knowledge
vianu/__init__.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from abc import ABC, abstractmethod
3
+ from pathlib import Path
4
+ from typing import Any, List
5
+
6
+ import gradio as gr
7
+ from gradio.events import Dependency
8
+
9
+ LOG_FMT = "%(asctime)s | %(name)s | %(funcName)s | %(levelname)s | %(message)s"
10
+
11
class BaseApp(ABC):
    """The abstract base class of the main gradio application.

    Subclasses implement :meth:`setup_ui` and :meth:`register_events`; :meth:`make`
    then assembles the :class:`gr.Blocks` demo with the shared state components rendered first.
    """

    def __init__(
        self,
        app_name: str | None = None,
        favicon_path: Path | None = None,
        allowed_paths: List[str] | None = None,
        head_file: Path | None = None,
        css_file: Path | None = None,
        theme: gr.Theme | None = None,
        local_state: Any | None = None,
        session_state: Any | None = None,
    ):
        """
        Args:
            app_name: The name of the application. Defaults to None.
            favicon_path: The favicon file as a :class:`pathlib.Path`. Defaults to None.
            allowed_paths: Filesystem paths gradio is allowed to serve (forwarded by the caller to `launch`). Defaults to None.
            head_file: Custom html code as a :class:`pathlib.Path` to a html file. Defaults to None.
            css_file (Path, optional): Custom css as a :class:`pathlib.Path` to a css file. Defaults to None.
            theme (gr.Theme, optional): The theme of the application. Defaults to None.
            local_state (Any, optional): The local state, where data persists in the browser's localStorage even after the page is refreshed or closed. Should be a json-serializable value (accessible only through it's serialized form). Defaults to None.
            session_state (Any, optional): The session state, where data persists across multiple submits within a page session. Defaults to None
        """
        self.favicon_path = favicon_path
        self.allowed_paths = allowed_paths

        self._app_name = app_name
        self._head_file = head_file
        self._css_file = css_file
        self._theme = theme

        # Wrap the raw state values in the corresponding gradio state components
        self._local_state = gr.BrowserState(local_state)
        self._session_state = gr.State(session_state)

    @abstractmethod
    def setup_ui(self):
        """Set up the user interface."""
        pass

    @abstractmethod
    def register_events(self):
        """Register the events."""
        pass

    def make(self) -> Dependency:
        """Assemble and return the gradio Blocks demo.

        Renders the state components before building the UI and registering events,
        so event handlers can reference them as inputs/outputs.
        """
        with gr.Blocks(
            title=self._app_name,
            head_paths=self._head_file,
            css_paths=self._css_file,
            theme=self._theme,
        ) as demo:
            self._local_state.render()
            self._session_state.render()
            self.setup_ui()
            self.register_events()

            # NOTE(review): `demo.load()` is called without a fn — presumably kept to trigger
            # front-end initialization on page load; confirm it is intentional.
            demo.load()
        return demo
vianu/spock/Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# Set environment variables to prevent Python from writing .pyc files and to buffer stdout and stderr
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# Set the working directory inside the container
WORKDIR /app
ENV PYTHONPATH=/app
# Bind gradio to all interfaces so the app is reachable from outside the container
ENV GRADIO_SERVER_NAME="0.0.0.0"

# Copy the files into the working directory
# NOTE(review): COPY sources are resolved relative to the build context; Docker ignores
# leading `../` segments that escape it — confirm the build is run with the repository
# root as context so these paths resolve as intended.
COPY ../../vianu/__init__.py /app/vianu/__init__.py
COPY ../../vianu/spock /app/vianu/spock

# Install dependencies from requirements.txt
RUN pip install --upgrade pip \
    && pip install -r vianu/spock/requirements.txt

# Expose the port your Gradio app will run on (default: 7860)
# NOTE(review): 7868 differs from the gradio default 7860 mentioned above — confirm it
# matches the GRADIO_SERVER_PORT used by the application.
EXPOSE 7868

# Command to run the application
CMD ["python", "vianu/spock/launch_demo_app.py"]
vianu/spock/__init__.py ADDED
File without changes
vianu/spock/__main__.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+ import asyncio
3
+ from datetime import datetime
4
+ import logging
5
+ import sys
6
+ from typing import List, Tuple
7
+
8
+ from vianu import LOG_FMT
9
+ from vianu.spock.settings import SCRAPING_SOURCES, LOG_LEVEL
10
+ from vianu.spock.src.cli import parse_args
11
+ from vianu.spock.src.base import Setup, Document, SpoCK, FileHandler
12
+ from vianu.spock.src import scraping as scp
13
+ from vianu.spock.src import ner
14
+
15
+
16
+ logging.basicConfig(format=LOG_FMT, level=LOG_LEVEL)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
async def _orchestrator(
    args_: Namespace,
    src_queue: asyncio.Queue,
    scp_queue: asyncio.Queue,
    ner_queue: asyncio.Queue,
    scp_tasks: List[asyncio.Task],
    ner_tasks: List[asyncio.Task],
) -> None:
    """Orchestrates the scraping and NER tasks.

    It waits for all scraping tasks to finish, then sends a sentinel to the scp_queue for each ner task (which will
    trigger the ner tasks to finish -> cf :func:`vianu.spock.src.ner.apply`).

    Args:
        args_: Parsed arguments; only ``args_.source`` (the sources to scrape) is read here.
        src_queue: Queue feeding the source names to the scraping tasks.
        scp_queue: Queue connecting scraping output to the NER tasks.
        ner_queue: Queue receiving the NER results; a final ``None`` sentinel marks the end of processing.
        scp_tasks: The scraping tasks consuming from ``src_queue``.
        ner_tasks: The NER tasks consuming from ``scp_queue``.
    """
    logger.debug('setting up orchestrator task')

    # Insert sources into the source queue
    sources = args_.source
    for src in sources:
        await src_queue.put(src)

    # Insert sentinel for each scraping task
    for _ in range(len(scp_tasks)):
        await src_queue.put(None)

    # Wait for all scraper tasks to finish and stop them
    await src_queue.join()
    try:
        await asyncio.gather(*scp_tasks)
    except asyncio.CancelledError:
        logger.warning('scraping task(s) have previously been canceled')
    except Exception as e:
        logger.error(f'scraping task(s) failed with error: {e}')
        raise e
    # Defensive cancel: the tasks have completed by now, so this is a no-op in the normal path
    for st in scp_tasks:
        st.cancel()

    # Insert sentinel for each NER
    for _ in range(len(ner_tasks)):
        await scp_queue.put(None)

    # Wait for NER tasks to process all items and finish
    await scp_queue.join()
    try:
        await asyncio.gather(*ner_tasks)
    except asyncio.CancelledError:
        logger.warning('ner task(s) have previously been canceled')
    except Exception as e:
        logger.error(f'ner task(s) failed with error: {e}')
        raise e
    for nt in ner_tasks:
        nt.cancel()

    # Insert sentinel into ner_queue to indicate end of processing
    await ner_queue.put(None)
74
+
75
+
76
def setup_asyncio_framework(args_: Namespace) -> Tuple[asyncio.Queue, List[asyncio.Task], List[asyncio.Task], asyncio.Task]:
    """Set up the asyncio framework for the SpoCK application.

    NOTE: mutates ``args_`` in place by defaulting ``args_.source`` to all known scraping sources.
    Must be called with a running event loop (uses :func:`asyncio.create_task`).

    Args:
        args_: Parsed arguments forwarded to the scraping/NER task factories.

    Returns:
        A tuple ``(ner_queue, scp_tasks, ner_tasks, orc_task)``: the queue carrying the final
        NER results, the scraping tasks, the NER tasks, and the orchestrator task.
    """
    # Set up arguments
    if args_.source is None:
        args_.source = SCRAPING_SOURCES

    # Set up queues
    src_queue = asyncio.Queue()
    scp_queue = asyncio.Queue()
    ner_queue = asyncio.Queue()

    # Start tasks
    scp_tasks = scp.create_tasks(args_=args_, queue_in=src_queue, queue_out=scp_queue)
    ner_tasks = ner.create_tasks(args_=args_, queue_in=scp_queue, queue_out=ner_queue)
    orc_task = asyncio.create_task(
        _orchestrator(
            args_=args_,
            src_queue=src_queue,
            scp_queue=scp_queue,
            ner_queue=ner_queue,
            scp_tasks=scp_tasks,
            ner_tasks=ner_tasks,
        )
    )
    return ner_queue, scp_tasks, ner_tasks, orc_task
101
+
102
+
103
async def _collector(ner_queue: asyncio.Queue) -> List[Document]:
    """Drain the NER queue and gather the processed documents.

    Consumes items until the ``None`` sentinel (placed by the orchestrator) arrives,
    acknowledging every item — sentinel included — so ``ner_queue.join()`` can return.
    """
    docs = []
    while (item := await ner_queue.get()) is not None:
        docs.append(item.doc)
        ner_queue.task_done()
    # Acknowledge the sentinel as well
    ner_queue.task_done()
    return docs
119
+
120
+
121
async def main(args_: Namespace | None = None, save: bool = True) -> None:
    """Main function for the SpoCK pipeline.

    Args:
        args_: Parsed CLI arguments; when None, they are parsed from ``sys.argv[1:]``.
        save: When True, the assembled results are written to disk (requires both
            ``args_.file_name`` and ``args_.file_path`` to be set).
    """
    started_at = datetime.now()
    if args_ is None:
        args_ = parse_args(sys.argv[1:])

    # NOTE(review): basicConfig was already called at import time above; this second call is a
    # no-op when the root logger already has handlers — confirm the CLI log level takes effect.
    logging.basicConfig(level=args_.log_level.upper(), format=LOG_FMT)
    logger.info(f'starting SpoCK (args_={args_})')

    # Set up async structure (scraping queue/tasks, NER queue/tasks, orchestrator task)
    ner_queue, _, _, _ = setup_asyncio_framework(args_)

    # Set up collector task and wait for it to finish
    # NOTE: if collector task is finished, the orchestrator is also finished (because of the sentinel in `ner_queue`)
    # and therefore so are the scraping and NER tasks
    col_task = asyncio.create_task(_collector(ner_queue))
    data = await col_task
    await ner_queue.join()

    # Save data
    if save:
        file_name = args_.file_name
        file_path = args_.file_path
        spock = SpoCK(
            id_=str(args_),
            status='completed',
            started_at=started_at,
            finished_at=datetime.now(),
            setup=Setup.from_namespace(args_),
            data=data,
        )
        if file_name is not None and file_path is not None:
            FileHandler(file_path=file_path).write(file_name=file_name, spock=spock)
    logger.info('finished SpoCK')


if __name__ == '__main__':
    asyncio.run(main())
vianu/spock/app/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
"""Public interface of the SpoCK gradio app package."""

from vianu.spock.app.app import App

__all__ = ['App']
vianu/spock/app/app.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from dataclasses import dataclass, field
3
+ from datetime import datetime
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Tuple
8
+
9
+ from dotenv import load_dotenv
10
+ import gradio as gr
11
+
12
+ from vianu.spock.settings import LOG_LEVEL, N_SCP_TASKS, N_NER_TASKS
13
+ from vianu.spock.settings import LARGE_LANGUAGE_MODELS, SCRAPING_SOURCES, MAX_DOCS_SRC
14
+ from vianu.spock.settings import GRADIO_APP_NAME, GRADIO_SERVER_PORT, GRADIO_MAX_JOBS, GRADIO_UPDATE_INTERVAL
15
+ from vianu.spock.settings import OLLAMA_BASE_URL_ENV_NAME, OPENAI_API_KEY_ENV_NAME
16
+ from vianu.spock.src.base import Setup, SpoCK, SpoCKList, QueueItem # noqa: F401
17
+ from vianu import BaseApp
18
+ from vianu.spock.__main__ import setup_asyncio_framework
19
+ import vianu.spock.app.formatter as fmt
20
+
21
logger = logging.getLogger(__name__)
load_dotenv()

# App settings
_ASSETS_PATH = Path(__file__).parents[1] / "assets"

# (display name, internal value) pairs for the settings radio/checkbox groups.
# NOTE: zip() silently truncates to the shorter sequence, hence the explicit length
# checks below fail fast if the display names fall out of sync with the settings module.
_UI_SETTINGS_LLM_CHOICES = list(zip(['Ollama', 'OpenAI'], LARGE_LANGUAGE_MODELS))
_UI_SETTINGS_SOURCE_CHOICES = list(zip(['PubMed', 'EMA', 'MHRA', 'FDA'], SCRAPING_SOURCES))
if len(_UI_SETTINGS_LLM_CHOICES) != len(LARGE_LANGUAGE_MODELS):
    raise ValueError('LARGE_LANGUAGE_MODELS and _UI_SETTINGS_LLM_CHOICES must have the same length')
if len(_UI_SETTINGS_SOURCE_CHOICES) != len(SCRAPING_SOURCES):
    raise ValueError('SCRAPING_SOURCES and _UI_SETTINGS_SOURCE_CHOICES must have the same length')
32
+
33
+
34
@dataclass
class LocalState:
    """The persistent local state (kept in the browser's localStorage via gr.BrowserState).

    Values default to the application-wide settings and survive page reloads.
    """
    log_level: str = LOG_LEVEL                        # logging verbosity for spawned pipelines
    n_scp_tasks: int = N_SCP_TASKS                    # number of concurrent scraping tasks
    n_ner_tasks: int = N_NER_TASKS                    # number of concurrent NER tasks
    max_jobs: int = GRADIO_MAX_JOBS                   # max. number of job cards kept in the UI
    update_interval: float = GRADIO_UPDATE_INTERVAL   # UI refresh period in seconds
42
+
43
+
44
@dataclass
class SessionState:
    """The session dependent state."""

    # Asyncio setup
    ner_queue: asyncio.Queue | None = None                       # queue carrying the final NER results
    scp_tasks: List[asyncio.Task] = field(default_factory=list)  # scraping tasks
    ner_tasks: List[asyncio.Task] = field(default_factory=list)  # named entity recognition tasks
    orc_task: asyncio.Task | None = None                         # orchestrator task
    col_task: asyncio.Task | None = None                         # collector task

    # Data
    is_running: bool = False                 # whether a pipeline run is currently in progress
    spocks: SpoCKList = field(default_factory=list)  # all jobs; newest is inserted at index 0
    _index_running_spock: int | None = None  # index into `spocks` of the currently running job
    _index_active_spock: int | None = None   # index into `spocks` of the job shown in the details view

    def set_running_spock(self, index: int | None) -> None:
        """Mark the job at `index` as the currently running one."""
        self._index_running_spock = index

    def get_running_spock(self) -> SpoCK:
        """Return the currently running job.

        NOTE(review): raises TypeError when no running index is set — callers appear to
        guarantee a running job exists before calling; confirm.
        """
        return self.spocks[self._index_running_spock]

    def set_active_spock(self, index: int | None) -> None:
        """Mark the job at `index` as the one shown in the details view."""
        self._index_active_spock = index

    def get_active_spock(self) -> SpoCK:
        """Return the job shown in the details view."""
        return self.spocks[self._index_active_spock]
73
+
74
+
75
class App(BaseApp):
    """The main gradio application."""

    def __init__(self):
        super().__init__(
            app_name=GRADIO_APP_NAME,
            favicon_path=_ASSETS_PATH / "images" / "favicon.png",
            allowed_paths=[str(_ASSETS_PATH.resolve())],
            head_file=_ASSETS_PATH / "head" / "scripts.html",
            css_file=_ASSETS_PATH / "css" / "styles.css",
            theme=gr.themes.Soft(),
            local_state=LocalState(),
            session_state=SessionState(),
        )
        # Registry of the gradio components, keyed as '<section>.<name>'
        self._components: Dict[str, Any] = {}

    # --------------------------------------------------------------------------
    # User Interface
    # --------------------------------------------------------------------------
    @staticmethod
    def _ui_top_row():
        """Top row with the logo and the title/description."""
        with gr.Row(elem_classes="top-row"):
            with gr.Column(scale=1):
                gr.Image(
                    value=_ASSETS_PATH / "images" / "spock_logo_circular.png",
                    show_label=False,
                    elem_classes="image",
                )
            with gr.Column(scale=5):
                value = """<div class='top-row title-desc'>
                <div class='top-row title-desc title'>SpoCK: Spotting Clinical Knowledge</div>
                <div class='top-row title-desc desc'><em>A tool for identifying <b>medicinal products</b> and <b>adverse drug reactions</b> inside publicly available literature</em></div>
                </div>
                """
                gr.Markdown(value=value)

    def _ui_corpus_settings(self):
        """Settings column."""
        with gr.Column(scale=1):
            with gr.Accordion(label='LLM Endpoint'):
                self._components['settings.llm_radio'] = gr.Radio(
                    label='Model', show_label=False, choices=_UI_SETTINGS_LLM_CHOICES, value='llama', interactive=True
                )

                # 'llama' specific settings
                with gr.Group(visible=True) as self._components['settings.ollama_group']:
                    value = os.environ.get("OLLAMA_BASE_URL")
                    placeholder = 'base_url of ollama endpoint' if value is None else None
                    gr.Markdown('---')
                    self._components['settings.ollama_base_url'] = gr.Textbox(
                        label='base_url',
                        show_label=False,
                        info='base_url',
                        placeholder=placeholder,
                        value=value,
                        interactive=True,
                    )

                # 'openai' specific settings
                with gr.Group(visible=False) as self._components['settings.openai_group']:
                    value = os.environ.get("OPENAI_API_KEY")
                    placeholder = 'api_key of openai endpoint' if value is None else None
                    # SECURITY: never log the raw key; mask it like `_set_openai_api_key` does
                    logger.debug(f'openai api_key={"*****" if value else None}')
                    gr.Markdown('---')
                    self._components['settings.openai_api_key'] = gr.Textbox(
                        label='api_key',
                        show_label=False,
                        info='api_key',
                        placeholder=placeholder,
                        value=value,
                        interactive=True,
                        type='password',
                    )

            with gr.Accordion(label='Sources', open=True):
                self._components['settings.source'] = gr.CheckboxGroup(
                    label='Sources', show_label=False, choices=_UI_SETTINGS_SOURCE_CHOICES, value=SCRAPING_SOURCES, interactive=True
                )
                self._components['settings.max_docs_src'] = gr.Number(
                    label='max_docs_src', show_label=False, info='max. number of documents per source', value=MAX_DOCS_SRC, interactive=True
                )

    def _ui_corpus_row(self):
        """Main corpus with settings, search field, job cards, and details"""
        with gr.Row(elem_classes="bottom-container"):
            self._ui_corpus_settings()
            self._ui_corpus_main()

    def _ui_corpus_main(self):
        """Search field, job cards, and details."""
        with gr.Column(scale=5):
            # Search text field and start/stop/cancel buttons
            with gr.Row(elem_classes="search-container"):
                with gr.Column(scale=3):
                    self._components['main.search_term'] = gr.Textbox(
                        label="Search", show_label=False, placeholder="Enter your search term"
                    )

                with gr.Column(scale=1, elem_classes='pipeline-button'):
                    self._components['main.start_button'] = gr.HTML('<div class="button-not-running">Start</div>', visible=True)
                    self._components['main.stop_button'] = gr.HTML('<div class="button-running">Stop</div>', visible=False)
                    self._components['main.cancel_button'] = gr.HTML('<div class="canceling">canceling...</div>', visible=False)

            # Job summary cards
            with gr.Row(elem_classes="jobs-container"):
                self._components['main.cards'] = [gr.HTML('', elem_id=f'job-{i}', visible=False) for i in range(GRADIO_MAX_JOBS)]

            # Details of the selected job
            with gr.Row():
                self._components['main.details'] = gr.HTML('<div class="details-container"></div>')

    def setup_ui(self):
        """Set up the user interface."""
        self._ui_top_row()
        self._ui_corpus_row()
        self._components['timer'] = gr.Timer(value=GRADIO_UPDATE_INTERVAL, active=False, render=True)

    # --------------------------------------------------------------------------
    # Helpers
    # --------------------------------------------------------------------------
    @staticmethod
    def _show_llm_settings(llm: str) -> Tuple[dict[str, Any], dict[str, Any]]:
        """Return visibility updates for the (ollama_group, openai_group) settings."""
        logger.debug(f'show {llm} model settings')
        if llm == 'llama':
            return gr.update(visible=True), gr.update(visible=False)
        elif llm == 'openai':
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    @staticmethod
    def _set_ollama_base_url(base_url: str) -> None:
        """Setup ollama base_url as environment variable."""
        logger.debug(f'set ollama base_url environment variable ({OLLAMA_BASE_URL_ENV_NAME}={base_url})')
        os.environ[OLLAMA_BASE_URL_ENV_NAME] = base_url

    @staticmethod
    def _set_openai_api_key(api_key: str) -> None:
        """Setup openai api_key as environment variable."""
        log_key = '*****' if api_key else 'None'
        logger.debug(f'set openai api key (api_key={log_key})')
        os.environ[OPENAI_API_KEY_ENV_NAME] = api_key

    @staticmethod
    def _feed_cards_to_ui(local_state: dict, session_state: SessionState) -> List[dict[str, Any]]:
        """For all existing SpoCKs, create and feed the job cards to the UI."""
        spocks = session_state.spocks
        logger.debug(f'feeding cards to UI (len(spocks)={len(spocks)})')

        # Create the job cards for the existing spocks
        cds = []
        for i, spk in enumerate(spocks):
            html = fmt.get_job_card_html(i, spk)
            cds.append(gr.update(value=html, visible=True))

        # Extend with not-visible cards
        # Note: for gradio >= 5.0.0 this logic could be replaced with a dynamic number of gr.Blocks
        # (see https://www.gradio.app/guides/dynamic-apps-with-render-decorator)
        cds.extend([gr.update(visible=False) for _ in range(local_state['max_jobs'] - len(spocks))])
        return cds

    @staticmethod
    def _feed_details_to_ui(session_state: SessionState) -> str:
        """Collect the html texts for the documents of the selected job and feed them to the UI."""
        if len(session_state.spocks) == 0:
            return fmt.get_details_html([])

        active_spock = session_state.get_active_spock()
        logger.debug(f'feeding details to UI (len(data)={len(active_spock.data)})')
        return fmt.get_details_html(active_spock.data)

    @staticmethod
    def _check_llm_settings(llm: str) -> None:
        """Check if the LLM settings are set; raises gr.Error if the chosen endpoint is unconfigured."""
        if llm == 'llama':
            if os.environ.get(OLLAMA_BASE_URL_ENV_NAME) is None:
                raise gr.Error('Ollama base_url is not set (submit value with Enter)')
        elif llm == 'openai':
            if os.environ.get(OPENAI_API_KEY_ENV_NAME) is None:
                raise gr.Error('OpenAI api_key is not set (submit value with Enter)')

    @staticmethod
    def _toggle_button(session_state: SessionState) -> Tuple[SessionState, dict[str, Any], dict[str, Any], dict[str, Any]]:
        """Toggle the state of the pipeline between running <-> not_running.

        As a result the corresponding buttons (Start, Stop, canceling...) are shown/hidden.
        """
        logger.debug(f'toggle button (is_running={session_state.is_running}->{not session_state.is_running})')
        session_state.is_running = not session_state.is_running
        if session_state.is_running:
            # Show the stop button and hide the start/cancel button
            return session_state, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
        else:
            # Show the start button and hide the stop/cancel button
            return session_state, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    @staticmethod
    def _show_cancel_button() -> Tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
        """Shows the cancel button and hides the start and stop button."""
        logger.debug('show cancel button')
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

    @staticmethod
    def _setup_spock(term: str, model: str, source: List[str], max_docs_src: int, local_state: dict, session_state: SessionState) -> SessionState:
        """Setup a new SpoCK object and insert it as the running/active job at index 0."""
        max_jobs = local_state['max_jobs']
        spocks = session_state.spocks

        # Check if the maximum number of jobs is reached and pop the last job if necessary
        if len(spocks) >= max_jobs:
            msg = f'max number of jobs ({max_jobs}); last job "{spocks[-1].setup.term}" is removed'
            gr.Warning(msg)
            logger.warning(msg)
            spocks.pop(-1)

        # Setup the running_spock and append it to the list of spocks
        msg = f'started SpoCK for "{term}"'
        gr.Info(msg)
        logger.info(msg)

        # Create new SpoCK object
        setup = Setup(
            id_=f'{term} {source} {model}',
            term=term,
            model=model,
            source=source,
            max_docs_src=max_docs_src,
            log_level=local_state['log_level'],
            n_scp_tasks=local_state['n_scp_tasks'],
            n_ner_tasks=local_state['n_ner_tasks'],
        )
        spock = SpoCK(
            id_=setup.id_,
            status='running',
            setup=setup,
            started_at=setup.submission,
            data=[]
        )

        # Set the running and focused SpoCK to be the new SpoCK object
        index = 0
        spocks.insert(index, spock)
        session_state.set_running_spock(index=index)
        session_state.set_active_spock(index=index)
        return session_state

    @staticmethod
    async def _collector(session_state: SessionState) -> None:
        """Append the processed document(s) from the `session_state.ner_queue` to the running spock."""
        running_spock = session_state.get_running_spock()
        ner_queue = session_state.ner_queue
        logger.debug(f'starting collector (term={running_spock.setup.term})')

        while True:
            item = await ner_queue.get()  # type: QueueItem
            # Check stopping condition (added by the `orchestrator` in `vianu.spock.__main__`)
            if item is None:
                ner_queue.task_done()
                break
            running_spock.data.append(item.doc)
            ner_queue.task_done()

    async def _setup_asyncio_framework(self, session_state: SessionState) -> SessionState:
        """Start the SpoCK processes by setting up the asyncio framework and starting the asyncio tasks.

        Main components of asyncio framework are:
        - ner_queue: queue for collecting results from named entity recognition tasks
        - scp_tasks: scraping tasks (cf. `vianu.spock.src.scp`)
        - ner_tasks: named entity recognition tasks (cf. `vianu.spock.src.ner`)
        - orc_task: orchestrating the process
        - col_task: collect and assemble the final results
        """
        logger.info("setting up asyncio framework")

        # Setup asyncio tasks as in `vianu.spock.__main__`
        args_ = session_state.get_running_spock().setup.to_namespace()
        ner_queue, scp_tasks, ner_tasks, orc_task = setup_asyncio_framework(args_=args_)
        session_state.ner_queue = ner_queue
        session_state.scp_tasks = scp_tasks
        session_state.ner_tasks = ner_tasks
        session_state.orc_task = orc_task

        # Setup the app specific collection task
        col_task = asyncio.create_task(self._collector(session_state=session_state))
        session_state.col_task = col_task

        return session_state

    @staticmethod
    async def _conclusion(session_state: SessionState) -> SessionState:
        """Await the collector, then mark the running job as completed."""
        # Wait for the collector task to finish and join ner_queue
        try:
            await session_state.col_task
        except asyncio.CancelledError:
            logger.warning('collector task canceled')
            return session_state  # This stops the _conclusion step in the case the _canceling step was triggered
        except Exception as e:
            logger.error(f'collector task failed with error: {e}')
            raise e
        await session_state.ner_queue.join()

        # Update the running_spock with the final data
        running_spock = session_state.get_running_spock()
        running_spock.status = 'completed'
        running_spock.finished_at = datetime.now()

        # Log the conclusion and update/empty the running_spock
        gr.Info(f'job "{running_spock.setup.term}" finished')
        logger.info(f'job "{running_spock.setup.term}" finished in {running_spock.runtime()}')

        return session_state

    @staticmethod
    async def _canceling(session_state: SessionState) -> SessionState:
        """Cancel all running :class:`asyncio.Task`."""
        running_spock = session_state.get_running_spock()
        gr.Info(f'canceled SpoCK for "{running_spock.setup.term}"')

        # Update the running_spock
        running_spock.status = 'stopped'
        running_spock.finished_at = datetime.now()

        # Cancel scraping tasks
        logger.warning("canceling scraping tasks")
        for task in session_state.scp_tasks:
            task.cancel()
        await asyncio.gather(*session_state.scp_tasks, return_exceptions=True)

        # Cancel named entity recognition tasks
        logger.warning("canceling named entity recognition tasks")
        for task in session_state.ner_tasks:
            task.cancel()
        await asyncio.gather(*session_state.ner_tasks, return_exceptions=True)

        # Cancel orchestrator task
        logger.warning("canceling orchestrator task")
        session_state.orc_task.cancel()
        await asyncio.gather(session_state.orc_task, return_exceptions=True)  # we use return_exceptions=True to avoid raising exceptions due to the subtasks being already canceled

        # Cancel collector task
        logger.warning("canceling collector task")
        session_state.col_task.cancel()
        await asyncio.gather(session_state.col_task, return_exceptions=True)  # see remark above

        return session_state

    @staticmethod
    def _change_active_spock_number(session_state: SessionState, index: int) -> SessionState:
        """Set the job at `index` as the one shown in the details view."""
        logger.debug(f'card clicked={index}')
        session_state.set_active_spock(index=index)
        return session_state

    # --------------------------------------------------------------------------
    # Events
    # --------------------------------------------------------------------------
    def _event_timer(self):
        """Periodically refresh the job cards and the details view."""
        self._components['timer'].tick(
            fn=self._feed_cards_to_ui,
            inputs=[self._local_state, self._session_state],
            outputs=self._components['main.cards'],
        ).then(
            fn=self._feed_details_to_ui,
            inputs=[self._session_state],
            outputs=self._components['main.details'],
        )

    def _event_choose_llm(self):
        """Choose LLM model and show the corresponding settings."""
        self._components['settings.llm_radio'].change(
            fn=self._show_llm_settings,
            inputs=self._components['settings.llm_radio'],
            outputs=[
                self._components['settings.ollama_group'],
                self._components['settings.openai_group'],
            ],
        )

    def _event_settings_ollama(self):
        """Callback of the ollama settings."""
        self._components['settings.ollama_base_url'].submit(
            fn=self._set_ollama_base_url,
            inputs=self._components['settings.ollama_base_url'],
        )

    def _event_settings_openai(self):
        """Callback of the openai settings."""
        self._components['settings.openai_api_key'].submit(
            fn=self._set_openai_api_key,
            inputs=self._components['settings.openai_api_key'],
        )

    def _event_start_spock(self) -> None:
        """Register the full pipeline triggered by submitting a search term or clicking Start."""
        search_term = self._components['main.search_term']
        start_button = self._components['main.start_button']
        timer = self._components['timer']

        gr.on(
            triggers=[search_term.submit, start_button.click],
            fn=self._check_llm_settings,
            inputs=self._components['settings.llm_radio'],
        ).success(
            fn=self._toggle_button,
            inputs=self._session_state,
            outputs=[
                self._session_state,
                self._components['main.start_button'],
                self._components['main.stop_button'],
                self._components['main.cancel_button'],
            ]
        ).then(
            fn=self._setup_spock,
            inputs=[
                search_term,
                self._components['settings.llm_radio'],
                self._components['settings.source'],
                self._components['settings.max_docs_src'],
                self._local_state,
                self._session_state,
            ],
            outputs=self._session_state
        ).then(
            fn=self._setup_asyncio_framework,
            inputs=self._session_state,
            outputs=self._session_state,
        ).then(
            fn=lambda: None, outputs=search_term  # Empty the search term in the UI
        ).then(
            fn=self._feed_cards_to_ui,
            inputs=[self._local_state, self._session_state],
            outputs=self._components['main.cards'],
        ).then(
            fn=lambda: gr.update(active=True), outputs=timer
        ).then(
            fn=self._conclusion,
            inputs=self._session_state,
            outputs=self._session_state,
        ).then(
            fn=self._feed_cards_to_ui,
            inputs=[self._local_state, self._session_state],
            outputs=self._components['main.cards'],
        ).then(
            fn=self._feed_details_to_ui,  # called one more time in order to enforce update of the details (regardless of the state of the timer)
            inputs=[self._session_state],
            outputs=self._components['main.details'],
        ).then(
            fn=lambda: gr.update(active=False), outputs=timer
        ).then(
            fn=self._toggle_button,
            inputs=self._session_state,
            outputs=[
                self._session_state,
                self._components['main.start_button'],
                self._components['main.stop_button'],
                self._components['main.cancel_button'],
            ]
        )

    def _event_stop_spock(self):
        """Register the cancel flow triggered by clicking Stop."""
        # NOTE: when `stop_button.click` is triggered, the above pipeline (started by `search_term.submit` or
        # `start_button.click`) is still running and awaiting the `_conclusion` step to finish. The `stop_button.click`
        # event will cause the `_conclusion` step to terminate, after which the subsequent steps will still be executed;
        # -> therefore, there is no need to add these steps here.
        self._components['main.stop_button'].click(
            fn=self._show_cancel_button,
            outputs=[
                self._components['main.start_button'],
                self._components['main.stop_button'],
                self._components['main.cancel_button'],
            ]
        ).then(
            fn=self._canceling,
            inputs=self._session_state,
            outputs=self._session_state,
        )

    def _event_card_click(self):
        """Register a click handler per job card that selects it for the details view."""
        for index, crd in enumerate(self._components['main.cards']):
            crd.click(
                fn=self._change_active_spock_number,
                inputs=[self._session_state, gr.Number(value=index, visible=False)],
                outputs=self._session_state,
            ).then(
                fn=self._feed_details_to_ui,
                inputs=[self._session_state],
                outputs=self._components['main.details'],
            )

    def register_events(self):
        """Register the events."""
        # Setup timer for feed cards and details
        self._event_timer()

        # Settings events
        self._event_choose_llm()
        self._event_settings_ollama()
        self._event_settings_openai()

        # Start/Stop events
        self._event_start_spock()
        self._event_stop_spock()

        # Card click events for showing details
        self._event_card_click()
579
+
580
if __name__ == "__main__":
    # NOTE(review): this re-import shadows the App class defined above in this module;
    # presumably kept for symmetry with package usage — the local definition would suffice.
    from vianu.spock.app import App

    app = App()
    demo = app.make()
    demo.queue().launch(
        favicon_path=app.favicon_path,
        inbrowser=True,
        allowed_paths=[
            str(_ASSETS_PATH.resolve()),
        ],
        server_port=GRADIO_SERVER_PORT,
    )
vianu/spock/app/formatter.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List
3
+
4
+ from vianu.spock.src.base import Document, SpoCK
5
+ from vianu.spock.settings import DATE_FORMAT
6
+
7
logger = logging.getLogger(__name__)


# HTML template for a single job card in the jobs container.
# Placeholders: title, status (pre-rendered HTML span), date, sources, n_doc,
# n_adr. `cardClickHandler` is defined in assets/head/scripts.html.
JOBS_CONTAINER_CARD_TEMPLATE = """
<div class="card" onclick="cardClickHandler(this)">
    <div class="title">{title} {status}</div>
    <div class="info">Date: {date}</div>
    <div class="info">Sources: {sources}</div>
    <div class="info">#docs: {n_doc} | #adr: {n_adr}</div>
</div>
"""

# Wrapper for the details pane; `items` is the concatenation of
# DETAILS_CONTAINER_ITEM_TEMPLATE renderings.
DETAILS_CONTAINER_TEMPLATE = """
<div id='details' class='details-container'>
    <div class='items'>{items}</div>
</div>
"""

# One document entry in the details pane: favicon + linked title on top,
# highlighted document text below.
DETAILS_CONTAINER_ITEM_TEMPLATE = """
<div class='item'>
    <div class='top'>
        <div class='favicon'><img src='{favicon}' alt='Favicon'></div>
        <div class='title'><a href='{url}'>{title}</a></div>
    </div>
    <div class='bottom'>
        {text}
    </div>
</div>
"""
36
+
37
+
38
def _get_details_html_items(data: List[Document]) -> str:
    """Render one HTML item per document for the details container.

    Each item contains the source favicon, a linked title (truncated to 120
    characters) and the document text with highlighted named entities
    (see ``Document.get_html``).
    """
    max_title_length = 120  # fixed typo: was `max_title_lenth`
    items = []
    for doc in data:
        # Robustness: Document.title is Optional — render an empty title
        # instead of raising when it is None.
        raw_title = doc.title or ""
        title = raw_title[:max_title_length]
        if len(raw_title) > max_title_length:
            title += "..."
        # Note: the original call also passed an unused `details=` kwarg;
        # it was dropped (str.format ignores extra keyword arguments).
        items.append(
            DETAILS_CONTAINER_ITEM_TEMPLATE.format(
                favicon=doc.source_favicon_url,
                url=doc.url,
                title=title,
                text=doc.get_html(),
            )
        )
    return "\n".join(items)
56
+
57
+
58
def get_details_html(data: List[Document]):
    """Return the stacked HTML items for all documents, most findings first.

    Documents are ordered by the number of adverse reactions (then
    medicinal products), descending; an empty list yields a placeholder.
    """
    if not data:
        return "<div>no results available (yet)</div>"
    by_relevance = sorted(
        data,
        key=lambda doc: (len(doc.adverse_reactions), len(doc.medicinal_products)),
        reverse=True,
    )
    return DETAILS_CONTAINER_TEMPLATE.format(
        items=_get_details_html_items(data=by_relevance)
    )
65
+
66
+
67
+ def _get_status_html(status: str) -> str:
68
+ """Get the HTML for the status."""
69
+ if status == "running":
70
+ return f"<span class='running'>({status.upper()})</span>"
71
+ elif status == "completed":
72
+ return f"<span class='completed'>({status.upper()})</span>"
73
+ elif status == "stopped":
74
+ return f"<span class='stopped'>({status.upper()})</span>"
75
+ else:
76
+ logger.error(f"unknown status: {status.upper()})")
77
+ return '<span>(status unknown)</span>'
78
+
79
def get_job_card_html(card_nmbr: int, spock: SpoCK) -> str:
    """Render the job-card HTML for a single SpoCK job.

    Args:
        card_nmbr: Position of the card in the jobs container. Kept for
            interface compatibility; the template currently defines no
            placeholder for it, so it is not rendered.
        spock: The job whose setup, status and data are summarized.
    """
    job = spock.setup
    data = spock.data

    # Note: the original also passed `nmbr=card_nmbr`, which the template
    # silently ignored (no `{nmbr}` placeholder); the dead kwarg was dropped.
    return JOBS_CONTAINER_CARD_TEMPLATE.format(
        title=job.term,
        status=_get_status_html(spock.status),
        date=job.submission.strftime(DATE_FORMAT),
        sources=", ".join(job.source),
        n_doc=len(data),
        n_adr=sum(len(doc.adverse_reactions) for doc in data),
    )
vianu/spock/assets/css/styles.css ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* no footer */
footer {
    display: none !important;
}

/* customize scrollbar in gr.Dataframe */
::-webkit-scrollbar {
    background: var(--background-fill-primary);
}
::-webkit-scrollbar-thumb {
    background-color: var(--border-color-primary);
    border: 4px solid transparent;
    border-radius: 100px;
    background-clip: content-box;
}
::-webkit-scrollbar-corner {
    background: var(--background-fill-primary);
}

/* header row: logo image next to the title/description box */
.top-row {
    display: flex;
    align-items: center;
}

.top-row .image {
    display: flex;
    justify-content: center;
    align-items: center;
    background-color: transparent !important;
    border: none !important;
    box-shadow: none !important;
    padding: 0 !important;
    height: 150px;
}

.top-row .title-desc {
    justify-content: center;
    display: block;
    height: 100%;
    background: var(--block-title-background-fill);
    border-radius: var(--radius-lg);
    font: var(--font);
    text-align: center;
    padding: var(--scale-0)
}

.top-row .title-desc .title {
    color: var(--block-title-text-color);
    font-size: var(--text-xxl);
    font-weight: var(--weight-extrabold);
}

.top-row .title-desc .desc {
    padding: 0px !important;
}

/* search bar with the start (not-running) / stop (running) / cancel buttons */
.search-container {
    display: flex;
    align-items: flex-end;
}


.search-container .button-not-running,
.search-container .button-running,
.search-container .canceling {
    cursor: pointer;
    display: flex;
    justify-content: center;
    align-items: center;
    color: var(--color-grey-100);
    border-radius: var(--radius-md);
    font-size: var(--text-xxl);
    font-weight: var(--weight-extrabold);
}

.search-container .button-not-running {
    background: linear-gradient(to right, var(--color-pink-500), var(--button-primary-background-fill));
}
.search-container .button-running {
    background: var(--color-red-600);
}
.search-container .canceling {
    background: var(--color-grey-500);
}

/* job cards: border-style is none by default and switched to solid by
   cardClickHandler in assets/head/scripts.html to mark the selected card */
.jobs-container {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(var(--size-80), 1fr));
}
.card {
    justify-self: center;
    cursor: pointer;
    border: none var(--spacing-md) var(--block-title-border-color);
    border-radius: var(--radius-lg);
    width: var(--size-80);
    padding: var(--scale-2);
    font: var(--font);
    background: var(--block-label-background-fill);
}
.card:active {
    transform: scale(0.97);
}
.card .title {
    font-weight: var(--weight-bold);
    font-size: var(--text-lg);
    color: var(--block-title-text-color);
    margin-bottom: var(--scale-0);
}
/* status badge colors; class names match the status strings emitted by
   formatter._get_status_html */
.card .title .running{
    color: var(--color-red-600);
    font-size: var(--text-md);
}
.card .title .stopped{
    color: var(--color-grey-500);
    font-size: var(--text-md);
}
.card .title .completed{
    color: var(--color-green-700);
    font-size: var(--text-md);
}
.card .info {
    color: var(--block-info-text-color);
    margin-bottom: var(--block-label-margin);
}

/* details pane listing the documents of the selected job */
.details-container {
    display: flex;
    border-radius: 8px;
}
.details-container .title {
    font-weight: var(--weight-bold);
    font-size: var(--text-xl);
    padding-left: var(--scale-0);
    margin: var(--block-label-margin) 0 var(--block-label-margin) 0;
}
.details-container .info {
    color: var(--block-info-text-color);
    padding-left: var(--scale-0);
    margin: var(--block-label-margin) 0 var(--block-label-margin) 0;
}

.details-container .items {
    margin-top: var(--scale-4);
}

.details-container .items .item {
    padding: var(--scale-0);
    margin-bottom: var(--scale-0);
    border: solid var(--block-label-border-width) var(--block-label-border-color);
}

.details-container .items .item .top{
    display: flex;
    align-items: center;
}

.details-container .items .item .top .favicon{
    display: flex;
    align-items: center;
    margin-right: var(--scale-0);
}

.details-container .items .item .top .favicon img{
    height: 1.5em;
}

.details-container .items .item .top .title{
    display: flex;
    align-items: center;
}

.details-container .items .item .bottom{
    display: flex;
    align-items: center;
}

/* highlighted named entities: medicinal products (mp) and adverse drug
   reactions (adr), as produced by Document.get_html */
.details-container .items .item .bottom .ner.mp,
.details-container .items .item .bottom .ner.adr {
    padding: 0px 5px;
    border-radius: 4px;
    font-weight: var(--weight-bold);
}

.details-container .items .item .bottom .ner.mp {
    color: var(--color-grey-100);
    background-color: var(--color-purple-700);
}

.details-container .items .item .bottom .ner.adr {
    /* NOTE(review): gradio themes expose shade-suffixed palette variables
       (e.g. --color-pink-500, as used above); confirm --color-pink resolves. */
    background-color: var(--color-pink);
}
vianu/spock/assets/head/scripts.html ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
<script>
    /* Highlight the clicked job card: clear the border style on every card,
       then draw a solid border on the selected one (cards default to
       border-style none in assets/css/styles.css). */
    function cardClickHandler(card) {
        var cards = document.getElementsByClassName("card");
        for (let c of cards) {
            c.style.borderStyle = "none";
        }

        card.style.borderStyle = "solid";
    }

</script>
vianu/spock/assets/images/favicon.png ADDED
vianu/spock/assets/images/spock_logo.png ADDED
vianu/spock/assets/images/spock_logo_circular.png ADDED
vianu/spock/launch_demo_app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import os

from vianu import LOG_FMT
from vianu.spock.settings import LOG_LEVEL, GRADIO_SERVER_PORT
from vianu.spock.app import App

# Configure root logging and pin the gradio server port before launching.
logging.basicConfig(level=LOG_LEVEL.upper(), format=LOG_FMT)
os.environ["GRADIO_SERVER_PORT"] = str(GRADIO_SERVER_PORT)

if __name__ == "__main__":
    # Build the gradio app and launch it with queuing enabled so
    # long-running pipeline jobs do not block the UI.
    app = App()
    demo = app.make()
    demo.queue().launch(
        favicon_path=app.favicon_path,
        inbrowser=True,
        allowed_paths=app.allowed_paths,
    )
vianu/spock/launch_demo_pipeline.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from argparse import Namespace
import asyncio

from vianu.spock.__main__ import main
from vianu.spock.settings import SCRAPING_SOURCES, MAX_DOCS_SRC
from vianu.spock.settings import N_SCP_TASKS, N_NER_TASKS


# Demo pipeline arguments mirroring the CLI flags defined in src/cli.py.
_ARGS = {
    'term': 'ibuprofen',
    'max_docs_src': MAX_DOCS_SRC,
    'source': SCRAPING_SOURCES,
    'model': 'llama',  # the 'llama' backend requires OLLAMA_BASE_URL to be set (see src/ner.py)
    'n_scp_tasks': N_SCP_TASKS,
    'n_ner_tasks': N_NER_TASKS,
    'log_level': 'DEBUG',
}


if __name__ == '__main__':
    # save=False: run the pipeline without persisting results to disk.
    asyncio.run(main(args_=Namespace(**_ARGS), save=False))
vianu/spock/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.11.11
2
+ beautifulsoup4==4.12.3
3
+ dacite==1.8.1
4
+ gradio==5.10.0
5
+ numpy==2.2.1
6
+ pymupdf==1.25.1
7
+ python-dotenv==1.0.1
8
+ openai==1.59.5
9
+ defusedxml==0.7.1
vianu/spock/settings.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# General settings
LOG_LEVEL = 'DEBUG'
N_CHAR_DOC_ID = 12  # number of doc-id characters shown in log messages (see src/ner.py)
FILE_PATH = "/tmp/spock/"  # nosec
FILE_NAME = "spock"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

# Gradio app settings
GRADIO_APP_NAME = "SpoCK"
GRADIO_SERVER_PORT=7868
GRADIO_MAX_JOBS = 5
GRADIO_UPDATE_INTERVAL = 2  # NOTE(review): presumably seconds between UI refreshes — confirm against the timer event

# Scraping settings
SCRAPING_SOURCES = ['pubmed', 'ema', 'mhra', 'fda']
MAX_CHUNK_SIZE = 500
MAX_DOCS_SRC = 5  # default maximum number of documents per source
N_SCP_TASKS = 2  # default number of concurrent scraping tasks

PUBMED_ESEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
PUBMED_EFETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
PUBMED_DB = 'pubmed'
PUBMED_BATCH_SIZE = 20

# NER settings
LARGE_LANGUAGE_MODELS = ['llama', 'openai']
# NOTE(review): 128.000 evaluates to the float 128.0 — if a 128k-token
# context window was intended this should be 128_000; confirm against usage.
MAX_TOKENS = 128.000
LLAMA_MODEL='llama3.2'
OPENAI_MODEL='gpt-4o'
N_NER_TASKS = 2  # default number of concurrent NER tasks
OLLAMA_BASE_URL_ENV_NAME = "OLLAMA_BASE_URL"
OPENAI_API_KEY_ENV_NAME = "OPENAI_API_KEY"
vianu/spock/src/__init__.py ADDED
File without changes
vianu/spock/src/base.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from argparse import Namespace
3
+ from dataclasses import dataclass, asdict, field
4
+ from datetime import datetime, timedelta
5
+ from hashlib import sha256
6
+ import json
7
+ import logging
8
+ import os
9
+ from pathlib import Path
10
+
11
+ import dacite
12
+ import numpy as np
13
+ from typing import List, Self
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ @dataclass
19
+ class Serializable:
20
+ """Abstract base class for all dataclasses that can be serialized to a dictionary."""
21
+ def to_dict(self) -> dict:
22
+ """Converts the object to a dictionary."""
23
+ return asdict(self)
24
+
25
+ @classmethod
26
+ def from_dict(cls, dict_: dict) -> Self:
27
+ """Creates an object from a dictionary."""
28
+ return dacite.from_dict(
29
+ data_class=cls,
30
+ data=dict_,
31
+ config=dacite.Config(type_hooks={datetime: datetime.fromisoformat})
32
+ )
33
+
34
+
35
@dataclass
class Identicator(ABC):
    """Abstract base class for entities with a customized, hashed id.

    Notes
        The identifier `id_` is sha256-hashed and prefixed with the
        subclass-specific `_id_prefix` unless it already carries that
        prefix — so re-instantiating with an existing id is a no-op.

        This behavior aims to allow:
            SubIdenticator(id_='This is the string that identifies the entity')

        and with _id_prefix='sub' it produces an id_ of the form:
            id_ = 'sub_5d41402abc4b2a76b9719d911017c592'
    """

    id_: str

    def __post_init__(self):
        # Hash + prefix exactly once: an already-prefixed id is kept verbatim.
        if not self.id_.startswith(self._id_prefix):
            self.id_ = f"{self._id_prefix}{self._hash_id_str()}"

    def __eq__(self, other: object) -> bool:
        # Equality is defined solely by the (hashed) identifier.
        if not isinstance(other, Identicator):
            return NotImplemented
        return self.id_ == other.id_

    def _hash_id_str(self) -> str:
        return sha256(self.id_.encode()).hexdigest()

    @property
    @abstractmethod
    def _id_prefix(self):
        pass
68
+
69
+
70
@dataclass(eq=False)
class NamedEntity(Identicator, Serializable):
    """A single named entity: surface text, class label and text location.

    eq=False keeps the inherited id-based Identicator.__eq__.
    """
    text: str = field(default_factory=str)  # exact copy of the source text span
    class_: str = field(default_factory=str)  # entity type, e.g. 'MP' or 'ADR' (see src/ner.py)
    location: List[int] | None = None  # [start, end) character offsets, set by NER._add_loc_for_named_entities

    @property
    def _id_prefix(self):
        return 'ent_'
80
+
81
+
82
@dataclass(eq=False)
class Document(Identicator, Serializable):
    """Class containing any document related information."""

    # mandatory document fields
    text: str
    source: str

    # additional document fields
    title: str | None = None
    url: str | None = None
    source_url: str | None = None
    source_favicon_url: str | None = None
    language: str | None = None
    publication_date: datetime | None = None

    # named entities
    medicinal_products: List[NamedEntity] = field(default_factory=list)
    adverse_reactions: List[NamedEntity] = field(default_factory=list)

    # protected fields: cache of the rendered HTML and its hash.
    # NOTE(review): as dataclass fields these are included in
    # asdict()/to_dict() serialization output — confirm that is intended.
    _html: str | None = None
    _html_hash: str | None = None

    @property
    def _id_prefix(self):
        return 'doc_'

    def remove_named_entity_by_id(self, id_: str) -> None:
        """Removes a named entity from the document by a given `doc.id_`."""
        self.medicinal_products = [ne for ne in self.medicinal_products if ne.id_ != id_]
        self.adverse_reactions = [ne for ne in self.adverse_reactions if ne.id_ != id_]

    def _get_html_hash(self) -> str:
        """Creates a sha256 hash from the named entities' ids. If the sets of named entities have been modified, this
        function will return a different hash.
        """
        ne_ids = [ne.id_ for ne in self.medicinal_products + self.adverse_reactions]
        html_hash_str = ' '.join(ne_ids)
        return sha256(html_hash_str.encode()).hexdigest()

    def _get_html(self) -> str:
        """Creates the HTML representation of the document with highlighted named entities."""
        text = f"<div>{self.text}</div>"

        # Highlight medicinal products according to the css class 'mp'.
        # NOTE(review): str.replace substitutes *every* occurrence of the
        # entity text; overlapping or repeated entity strings may produce
        # nested spans — confirm acceptable for the UI.
        mp_template = "<span class='ner mp'>{text} | {class_}</span>"
        for ne in self.medicinal_products:
            text = text.replace(
                ne.text, mp_template.format(text=ne.text, class_=ne.class_)
            )

        # Highlight adverse drug reactions according to the css class 'adr'
        adr_template = "<span class='ner adr'>{text} | {class_}</span>"
        for ne in self.adverse_reactions:
            text = text.replace(
                ne.text, adr_template.format(text=ne.text, class_=ne.class_)
            )

        return text

    def get_html(self) -> str:
        """Returns the HTML representation of the document with highlighted named entities. This function checks if
        the set of named entities has been modified and updates the HTML representation if necessary."""
        html_hash = self._get_html_hash()
        if self._html is None or html_hash != self._html_hash:
            self._html = self._get_html()
            self._html_hash = html_hash
        return self._html
153
+
154
+
155
@dataclass(eq=False)
class Setup(Identicator, Serializable):
    """Pipeline setup, closely mirroring the CLI arguments (see src/cli.py)."""

    # generic options
    log_level: str
    max_docs_src: int

    # scraping options
    term: str
    source: List[str]
    n_scp_tasks: int

    # NER options
    model: str
    n_ner_tasks: int

    # optional fields
    submission: datetime | None = None
    file_path: str | None = None
    file_name: str | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        # Default the submission timestamp to "now" when not provided.
        if self.submission is None:
            self.submission = datetime.now()

    @property
    def _id_prefix(self) -> str:
        return 'stp_'

    def to_namespace(self) -> Namespace:
        """Converts the :class:`Setup` object to a :class:`argparse.Namespace` object."""
        return Namespace(**asdict(self))

    @classmethod
    def from_namespace(cls, args_: Namespace) -> Self:
        """Creates a :class:`Setup` object from a :class:`argparse.Namespace` object."""
        args_dict = vars(args_)
        return cls(id_=str(args_dict), **args_dict)
194
+
195
+
196
@dataclass
class QueueItem:
    """Item passed through the :class:`asyncio.Queue` pipeline stages."""
    id_: str  # item identifier, distinct from doc.id_ (both logged in src/ner.py)
    doc: Document  # the document being processed/annotated
201
+
202
+
203
# eq=False keeps the inherited id-based Identicator.__eq__: a bare
# @dataclass would regenerate a field-wise __eq__ and silently override it.
# This matches the other Identicator subclasses (NamedEntity, Document, Setup).
@dataclass(eq=False)
class SpoCK(Identicator, Serializable):
    """Main class for the SpoCK pipeline mainly containing the job definition and the resulting data."""

    # Generic fields
    status: str | None = None  # e.g. 'running' | 'completed' | 'stopped' (rendered by app/formatter.py)
    started_at: datetime | None = None
    finished_at: datetime | None = None

    # Pipeline fields
    setup: Setup | None = None
    data: List[Document] = field(default_factory=list)

    @property
    def _id_prefix(self) -> str:
        return 'spk_'

    def runtime(self) -> timedelta | None:
        """Elapsed runtime.

        Counts up to `now` while the job is still unfinished; returns None
        when the job never started.
        """
        if self.started_at is None:
            return None
        end = self.finished_at if self.finished_at is not None else datetime.now()
        return end - self.started_at


SpoCKList = List[SpoCK]
228
+
229
+
230
class JSONEncoder(json.JSONEncoder):
    """JSON encoder aware of datetimes, :class:`Document` objects and numpy scalars."""

    def default(self, obj):
        """Serialize the supported extra types; defer to the base class otherwise."""
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, Document):
            return asdict(obj)
        if isinstance(obj, np.float32):
            # np.float32 is serialized via str()
            return str(obj)
        if isinstance(obj, np.int64):
            return int(obj)
        # Let the base class default method raise the TypeError
        return super().default(obj)
243
+
244
+
245
class FileHandler:
    """Reads from and writes data to a JSON file under a given file path."""

    _suffix = '.json'

    def __init__(self, file_path: Path | str) -> None:
        """Create the handler; the target directory is created if missing."""
        self._file_path = Path(file_path)  # Path() is idempotent, no isinstance check needed
        # parents/exist_ok: robust against nested paths and concurrent creation
        # (replaces the racy `if not exists: os.makedirs(...)`)
        self._file_path.mkdir(parents=True, exist_ok=True)

    def read(self, filename: str) -> SpoCK:
        """Reads the data from a JSON file and casts it into a :class:`SpoCK` object.

        Note: the return annotation was corrected (was List[Document], but
        the method returns a SpoCK).
        """
        file_ = (self._file_path / filename).with_suffix(self._suffix)

        # fixed: log the actual file instead of a hard-coded placeholder,
        # and apply .with_suffix only once
        logger.info(f'reading data from file {file_}')
        with open(file_, 'r', encoding="utf-8") as dfile:
            dict_ = json.load(dfile)

        return SpoCK.from_dict(dict_=dict_)

    def write(self, file_name: str, spock: SpoCK, add_dt: bool = True) -> None:
        """Writes the data to a JSON file.

        If `add_dt=True`, the filename is `{file_name}_%Y%m%d%H%M%S.json`.
        """
        if add_dt:
            file_name = f'{file_name}_{datetime.now().strftime("%Y%m%d%H%M%S")}'
        file_ = (self._file_path / file_name).with_suffix(self._suffix)

        logger.info(f'writing data to file {file_}')
        with open(file_, 'w', encoding="utf-8") as dfile:
            json.dump(spock.to_dict(), dfile, cls=JSONEncoder)
vianu/spock/src/cli.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI for SpoCK
2
+ """
3
+
4
+ import argparse
5
+ from typing import Sequence
6
+
7
+ from vianu.spock.settings import LOG_LEVEL, FILE_NAME, FILE_PATH, MAX_DOCS_SRC
8
+ from vianu.spock.settings import SCRAPING_SOURCES, N_SCP_TASKS
9
+ from vianu.spock.settings import N_NER_TASKS, LARGE_LANGUAGE_MODELS
10
+
11
+
12
def parse_args(args_: Sequence) -> argparse.Namespace:
    """Parse the SpoCK command-line arguments into a Namespace.

    Arguments are organized in three groups: generic, scraping and ner;
    defaults come from vianu.spock.settings.
    """
    parser = argparse.ArgumentParser(
        description="SpoCK",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Generic options
    generic = parser.add_argument_group('generic')
    generic.add_argument("--log-level", metavar='', type=str, default=LOG_LEVEL, help='log level')
    generic.add_argument("--file-path", metavar='', type=str, default=FILE_PATH, help='path for storing results')
    generic.add_argument("--file-name", metavar='', type=str, default=FILE_NAME, help='filename for storing results')
    generic.add_argument("--max-docs-src", metavar='', type=int, default=MAX_DOCS_SRC, help='maximum number of documents per source')

    # Scraping options ('-s' may be repeated to select several sources)
    scraping = parser.add_argument_group('scraping')
    scraping.add_argument('--source', '-s', type=str, action='append', choices=SCRAPING_SOURCES, help='data sources for scraping')
    scraping.add_argument('--term', '-t', metavar='', type=str, help='search term')
    scraping.add_argument('--n-scp-tasks', metavar='', type=int, default=N_SCP_TASKS, help='number of async scraping tasks')

    # NER options
    ner = parser.add_argument_group('ner')
    ner.add_argument('--model', '-m', type=str, choices=LARGE_LANGUAGE_MODELS, default='llama', help='NER model')
    ner.add_argument('--n-ner-tasks', metavar='', type=int, default=N_NER_TASKS, help='number of async ner tasks')

    return parser.parse_args(args_)
vianu/spock/src/ner.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import aiohttp
3
+ from argparse import Namespace
4
+ import asyncio
5
+ import logging
6
+ import os
7
+ import re
8
+ from typing import List
9
+
10
+ from dotenv import load_dotenv
11
+ from openai import AsyncOpenAI
12
+
13
+ from vianu.spock.src.base import NamedEntity, QueueItem # noqa: F401
14
+ from vianu.spock.settings import N_CHAR_DOC_ID, LLAMA_MODEL, OPENAI_MODEL
15
+
16
logger = logging.getLogger(__name__)
load_dotenv()  # pick up OLLAMA_BASE_URL / OPENAI_API_KEY from a local .env file

# Few-shot system prompt instructing the LLM to return named entities as a
# Python list of (text, type) tuples; the answer is parsed downstream with
# NER._named_entity_pattern.
NAMED_ENTITY_PROMPT = """
You are an expert in Natural Language Processing. Your task is to identify named entities (NER) in a given text.
You will focus on the following entities: adverse drug reaction (entity type: ADR), medicinal product (entity type: MP).
Once you identified all named entities of the above types, you return them as a Python list of tuples of the form (text, type).
It is important to only provide the Python list as your output, without any additional explanations or text.
In addition, make sure that the named entity texts are exact copies of the original text segment

Example 1:
Input:
"The most commonly reported side effects of dafalgan include headache, nausea, and fatigue."

Output:
[("dafalgan", "MP"), ("headache", "ADR"), ("nausea", "ADR"), ("fatigue", "ADR")]

Example 2:
Input:
"Patients taking acetaminophen or naproxen have reported experiencing skin rash, dry mouth, and difficulty breathing after taking this medication. In rare cases, seizures have also been observed."

Output:
[("acetaminophen", "MP"), ("naproxen", "MP"), ("skin rash", "ADR"), ("dry mouth", "ADR"), ("difficulty breathing", "ADR"), ("seizures", "ADR")]

Example 3:
Input:
"There are reported side effects as dizziness, stomach upset, and in some instances, temporary memory loss. These are mainly observed after taking Amitiza (lubiprostone) or Trulance (plecanatide)."

Output:
[("dizziness", "ADR"), ("stomach upset", "ADR"), ("temporary memory loss", "ADR"), ("Amitiza (lubiprostone)", "MP"), ("Trulance (plecanatide)", "MP")]
"""
47
+
48
+
49
class NER(ABC):
    """Abstract base class for named-entity-recognition workers.

    Subclasses implement `_get_ner_model_answer` against a concrete LLM
    backend; `apply` runs the queue-driven annotation loop.
    """

    # Matches the ("text", "MP"|"ADR") tuples the prompt asks the model to emit.
    _named_entity_pattern = re.compile(r'\("([^"]+)",\s?"(MP|ADR)"\)')

    def __init__(self):
        # Per-subclass logger so backend log lines are distinguishable.
        self.logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def _get_loc_of_subtext(text: str, subtext: str) -> List[int] | None:
        """Get the location of a subtext in a text as [start, end), or None."""
        pos = text.find(subtext)
        if pos == -1:
            return None
        return [pos, pos + len(subtext)]

    def _add_loc_for_named_entities(self, text: str, named_entities: List[NamedEntity]) -> None:
        """Locate each entity in `text` case-insensitively (in place).

        On a hit, `ne.location` is set and `ne.text` is replaced by the exact
        original casing from `text`; misses are logged and left without location.
        """
        txt_low = text.lower()
        for ne in named_entities:
            ne_txt_low = ne.text.lower()
            loc = self._get_loc_of_subtext(text=txt_low, subtext=ne_txt_low)

            if loc is not None:
                ne.location = loc
                ne.text = text[loc[0]:loc[1]]
            else:
                self.logger.warning(f'could not find location for named entity "{ne.text}" of class "{ne.class_}"')

    @staticmethod
    def _get_messages(text: str) -> List[dict]:
        """Build the chat messages: few-shot system prompt, then the user text."""
        text = f'Process the following input text: "{text}"'
        return [
            {
                "role": "system",
                "content": NAMED_ENTITY_PROMPT,
            },
            {
                "role": "user",
                "content": text,
            },
        ]

    @abstractmethod
    async def _get_ner_model_answer(self, text: str) -> str:
        # NOTE(review): the message names OpenAINER but this is the generic
        # abstract base — consider rewording.
        raise NotImplementedError('OpenAINER._get_ner_model_answer is not implemented yet')

    async def apply(self, queue_in: asyncio.Queue, queue_out: asyncio.Queue) -> None:
        """Apply NER to a text received from input queue and put the results in an output queue.

        Runs until a `None` sentinel is received on `queue_in`. Items whose
        model call fails are dropped (logged, not forwarded to `queue_out`).
        """
        while True:
            # Get text from input queue
            item = await queue_in.get()  # type: QueueItem

            # Check stopping condition
            if item is None:
                queue_in.task_done()
                break

            # Get the model response with named entities
            id_ = item.id_
            doc = item.doc
            self.logger.debug(f'starting ner for item.id_={id_} (doc.id_={doc.id_[:N_CHAR_DOC_ID]})')
            try:
                text = doc.text
                content = await self._get_ner_model_answer(text=text)
            except Exception as e:
                # Best-effort: skip this document but keep the worker alive.
                self.logger.error(f'error during ner for item.id_={item.id_} (doc.id_={doc.id_[:N_CHAR_DOC_ID]}): {e}')
                queue_in.task_done()
                continue

            # Parse the model answer and remove duplicates
            ne_list = re.findall(self._named_entity_pattern, content)
            ne_list = list(set(ne_list))

            # Create list of NamedEntity objects; the id is derived from
            # (source text, entity text, class) so duplicates hash equal.
            named_entities = []
            for ne in ne_list:
                try:
                    txt, cls_ = ne
                    named_entities.append(NamedEntity(id_=f'{text} {txt} {cls_}', text=txt, class_=cls_))
                except Exception as e:
                    self.logger.error(f'error during creation of `NamedEntity` using {ne}: {e}')

            # Add locations to named entities and remove those without location
            self._add_loc_for_named_entities(text=text, named_entities=named_entities)
            named_entities = [ne for ne in named_entities if ne.location is not None]

            # Assign named entities to the document
            ne_mp = [ne for ne in named_entities if ne.class_ == 'MP']
            ne_adr = [ne for ne in named_entities if ne.class_ == 'ADR']
            self.logger.debug(f'found #mp={len(ne_mp)} and #adr={len(ne_adr)} for item.id_={id_} (doc.id_={doc.id_[:N_CHAR_DOC_ID]})')
            doc.medicinal_products = ne_mp
            doc.adverse_reactions = ne_adr

            # Put the document in the output queue
            await queue_out.put(item)
            queue_in.task_done()
            self.logger.info(f'finished NER task for item.id_={id_} (doc.id_={doc.id_[:N_CHAR_DOC_ID]})')
151
+
152
+
153
class OllamaNER(NER):
    """NER backend that talks to an Ollama server's chat API."""

    def __init__(self, base_url: str, model: str):
        """
        Args:
            base_url: Base URL of the Ollama server (taken from OLLAMA_BASE_URL).
            model: Ollama model name (e.g. 'llama3.2', see settings.LLAMA_MODEL).
        """
        super().__init__()
        self._base_url = base_url
        self._model = model

    def _get_http_data(self, text: str, stream: bool = False) -> dict:
        """Build the JSON payload for the chat request (non-streaming by default)."""
        messages = self._get_messages(text=text)
        data = {
            "model": self._model,
            "messages": messages,
            "stream": stream,
        }
        return data

    async def _get_ner_model_answer(self, text: str) -> str:
        """POST the chat request and return the model's message content."""
        data = self._get_http_data(text=text)
        async with aiohttp.ClientSession() as session:
            # NOTE(review): Ollama documents the endpoint as /api/chat
            # (no trailing slash) — confirm the deployed server accepts
            # the trailing slash.
            url = f'{self._base_url}/api/chat/'
            async with session.post(url, json=data) as response:
                response.raise_for_status()
                resp_json = await response.json()
                content = resp_json['message']['content']
                return content
180
+
181
+
182
class OpenAINER(NER):
    """NER backend that uses the OpenAI async chat-completions API."""

    def __init__(self, api_key: str, model: str):
        """
        Args:
            api_key: OpenAI API key (taken from OPENAI_API_KEY).
            model: OpenAI model name (e.g. 'gpt-4o', see settings.OPENAI_MODEL).
        """
        super().__init__()
        self._model = model
        self._client = AsyncOpenAI(api_key=api_key)

    async def _get_ner_model_answer(self, text: str) -> str:
        """Send the chat completion request and return the first choice's content."""
        messages = self._get_messages(text=text)
        chat_completion = await self._client.chat.completions.create(
            messages=messages,
            model=self._model,
        )
        return chat_completion.choices[0].message.content
197
+
198
+
199
def create_tasks(args_: Namespace, queue_in: asyncio.Queue, queue_out: asyncio.Queue) -> List[asyncio.Task]:
    """Create asyncio NER tasks.

    Args:
        args_: Parsed arguments; `model` selects the backend ('llama' via
            Ollama, or 'openai') and `n_ner_tasks` the number of workers.
        queue_in: Queue of items to annotate.
        queue_out: Queue receiving the annotated items.

    Raises:
        EnvironmentError: If the selected backend's required environment
            variable (OLLAMA_BASE_URL / OPENAI_API_KEY) is unset.
        ValueError: If `args_.model` is not a known backend.
    """
    n_ner_tasks = args_.n_ner_tasks
    model = args_.model

    if model == 'llama':
        base_url = os.environ.get('OLLAMA_BASE_URL')
        if base_url is None:
            # fixed: the message previously referred to a non-existent
            # OLLAMA_ENDPOINT variable while the code reads OLLAMA_BASE_URL
            raise EnvironmentError("The ollama endpoint must be set by the OLLAMA_BASE_URL environment variable")
        ner = OllamaNER(base_url=base_url, model=LLAMA_MODEL)

    elif model == 'openai':
        api_key = os.environ.get('OPENAI_API_KEY')
        if api_key is None:
            raise EnvironmentError("The api_key for the OpenAI client must be set by the OPENAI_API_KEY environment variable")
        ner = OpenAINER(api_key=api_key, model=OPENAI_MODEL)

    else:
        raise ValueError(f"unknown ner model '{args_.model}'")

    logger.info(f'setting up {n_ner_tasks} NER task(s)')
    tasks = [asyncio.create_task(ner.apply(queue_in=queue_in, queue_out=queue_out)) for _ in range(n_ner_tasks)]
    return tasks
vianu/spock/src/scraping.py ADDED
@@ -0,0 +1,922 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for scraping data from different sources.
2
+
3
+ The module contains three main classes:
4
+ - :class:`Scraper`: Abstract base class for scraping data from different sources
5
+ - :class:`PubmedScraper`: Class for scraping data from the PubMed database
6
+ - :class:`EMAScraper`: Class for scraping data from the European Medicines Agency
7
+ """
8
+
9
from abc import ABC, abstractmethod
from argparse import Namespace
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
from io import BytesIO
import logging
import re
from typing import List
import xml.etree.ElementTree as ET  # nosec

import aiohttp
from bs4 import BeautifulSoup
from bs4.element import Tag
import defusedxml.ElementTree as DET
import numpy as np
import pymupdf

from vianu.spock.src.base import Document, QueueItem  # noqa: F401
from vianu.spock.settings import MAX_CHUNK_SIZE, SCRAPING_SOURCES
from vianu.spock.settings import PUBMED_ESEARCH_URL, PUBMED_DB, PUBMED_EFETCH_URL, PUBMED_BATCH_SIZE
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
class Scraper(ABC):
    """Abstract base class for scraping data from a single source."""

    def __init__(self):
        # Per-subclass logger so log lines identify the concrete scraper.
        self.logger = logging.getLogger(self.__class__.__name__)

    @abstractmethod
    async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
        """Main function for scraping data from a source.

        Args:
            - args_: the arguments for the spock pipeline
            - queue_out: the output queue for the scraped data
        """
        pass

    @staticmethod
    def split_text_into_chunks(text: str, chunk_size: int = MAX_CHUNK_SIZE, separator: str = ' ') -> List[str]:
        """Split a text into chunks of at most ``chunk_size`` words.

        The words are distributed as evenly as possible over the chunks. Using the
        ceiling of N / chunk_size as the number of chunks guarantees no chunk exceeds
        the requested maximum (the previous floor division produced oversized chunks,
        e.g. 250 words with chunk_size=100 yielded two chunks of 125 words each).
        """
        words = text.split(separator)
        N = len(words)
        s = min(chunk_size, N)
        # Ceiling division: the smallest number of chunks such that each holds <= s words.
        n = -(-N // s)
        bnd = [round(i) for i in np.linspace(0, 1, n + 1) * N]

        chunks = [separator.join(words[start:stop]) for start, stop in zip(bnd[:-1], bnd[1:])]
        return chunks

    @staticmethod
    async def _aiohttp_get_html(url: str, headers=None) -> str:
        """Get the content of a given URL by an aiohttp GET request.

        Raises:
            aiohttp.ClientResponseError: if the server answers with an HTTP error status.
        """
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.get(url=url) as response:
                response.raise_for_status()
                text = await response.text()
                return text
69
+
70
+
71
@dataclass
class PubmedEntrezHistoryParams:
    """Class for optimizing Pubmed database retrieval for large numbers of documents.

    An example can be found here:
    https://www.ncbi.nlm.nih.gov/books/n/helpeutils/chapter3/#chapter3.Application_3_Retrieving_large
    """
    # WebEnv token identifying the cached result set on the entrez history server
    web: str
    # QueryKey of the stored search within the WebEnv session
    key: str
    # total number of documents matched by the search (from the <Count> tag)
    count: int
81
+
82
+
83
class PubmedScraper(Scraper):
    """Class for scraping data from the PubMed database.

    The scraper uses the Pubmed API to search for relevant documents. From the list of results it creates a list of
    :class:`Document` objects by the following main steps:
    - Extract all PubmedArticle elements from the search results (other types are ignored)
    - Extract the AbstractText from the PubmedArticle (if there is no abstract, the document is ignored)
    """

    _source = 'pubmed'  # source identifier attached to every Document
    _source_url = 'https://pubmed.ncbi.nlm.nih.gov/'  # also the base for per-article URLs (_source_url + pmid + '/')
    _source_favicon_url = 'https://www.ncbi.nlm.nih.gov/favicon.ico'

    _robots_txt_url = 'https://www.ncbi.nlm.nih.gov/robots.txt'

    @staticmethod
    def _get_entrez_history_params(text: str) -> PubmedEntrezHistoryParams:
        """Retrieving the entrez history parameters for optimized search when requesting large numbers of documents.
        An example can be found here:
        https://www.ncbi.nlm.nih.gov/books/n/helpeutils/chapter3/#chapter3.Application_3_Retrieving_large

        NOTE(review): assumes the ESearch XML always contains WebEnv, QueryKey and Count;
        a missing tag makes re.search return None and raises AttributeError on .group().
        """
        web = re.search(r'<WebEnv>(\S+)</WebEnv>', text).group(1)
        key = re.search(r'<QueryKey>(\d+)</QueryKey>', text).group(1)
        count = int(re.search(r'<Count>(\d+)</Count>', text).group(1))
        return PubmedEntrezHistoryParams(web=web, key=key, count=count)

    async def _pubmed_esearch(self, term: str) -> str:
        """Search the Pubmed database with a given term and POST the results to entrez history server."""
        # usehistory=y stores the hit list server-side so EFetch can page through it.
        url = f'{PUBMED_ESEARCH_URL}?db={PUBMED_DB}&term={term}&usehistory=y'
        self.logger.debug(f'search pubmed database with url={url}')
        esearch = await self._aiohttp_get_html(url=url)
        return esearch

    async def _pubmed_efetch(self, params: PubmedEntrezHistoryParams, max_docs_src: int) -> List[str]:
        """Retrieve the relevant documents from the entrez history server.

        Returns the raw XML of each fetched batch (not yet parsed).
        """
        # Reduce the number of documents to be retrieved for efficiency
        N = min(max_docs_src, int(params.count))
        if N < params.count:
            self.logger.warning(f'from the total number of documents={params.count} only {N} will be retrieved')

        # Iterate over the batches of documents (with fixed batch size)
        batch_size = min(params.count, PUBMED_BATCH_SIZE)
        # NOTE(review): the batch count in this log line overshoots by one when N is an
        # exact multiple of batch_size.
        self.logger.debug(f'fetch #docs={N} in {N // batch_size + 1} batch(es) of size <= {batch_size}')
        batches = []
        for retstart in range(0, N, batch_size):

            # Prepare URL for retrieving next batch of documents but stop if the maximum number is reached
            # (the last batch is shrunk so the total never exceeds max_docs_src).
            retmax = min(max_docs_src - len(batches)*batch_size, batch_size)
            url = f'{PUBMED_EFETCH_URL}?db={PUBMED_DB}&WebEnv={params.web}&query_key={params.key}&retstart={retstart}&retmax={retmax}'
            self.logger.debug(f'fetch documents with url={url}')

            # Fetch the documents
            efetch = await self._aiohttp_get_html(url=url)
            batches.append(efetch)
        return batches

    def _extract_medline_citation(self, element: ET.Element) -> ET.Element | None:
        """Extract the MedlineCitation element from a PubmedArticle element."""
        # Find and extract the MedlineCitation element
        citation = element.find('MedlineCitation')
        if citation is None:
            self.logger.warning('no "MedlineCitation" element found')
            return None
        return citation

    @staticmethod
    def _extract_pmid(element: ET.Element) -> str | None:
        """Extract the PMID from a MedlineCitation element."""
        pmid = element.find('PMID')
        return pmid.text if pmid is not None else None

    @staticmethod
    def _extract_article(element: ET.Element) -> ET.Element | None:
        """Extract the article element from a PubmedArticle element."""
        # Find and extract the Article element (None if the citation has no Article child)
        article = element.find('Article')
        return article

    @staticmethod
    def _extract_title(article: ET.Element) -> str | None:
        """Extract the title from an Article element."""
        title = article.find('ArticleTitle')
        return title.text if title is not None else None

    @staticmethod
    def _extract_abstract(article: ET.Element) -> str | None:
        """Extract the abstract from an Article element.

        Multiple AbstractText sections (e.g. Background/Methods/Results) are joined
        with blank lines; returns None when the article has no Abstract at all.
        """
        separator = '\n\n'
        abstract = article.find('Abstract')
        if abstract is not None:
            abstract = separator.join([a.text for a in abstract.findall('AbstractText') if a.text is not None])
        return abstract

    @staticmethod
    def _extract_language(article: ET.Element) -> str | None:
        """Extract the language from an Article element."""
        language = article.find('Language')
        return language.text if language is not None else None

    @staticmethod
    def _extract_date(article: ET.Element) -> datetime | None:
        """Extract the publication date from an Article element.

        NOTE(review): assumes that when ArticleDate is present it always carries
        Year/Month/Day children; a partial date would raise AttributeError here.
        """
        date = article.find('ArticleDate')
        if date is None:
            return None
        year = int(date.find('Year').text)
        month = int(date.find('Month').text)
        day = int(date.find('Day').text)
        return datetime(year=year, month=month, day=day)

    @staticmethod
    def _extract_publication_types(article: ET.Element) -> List[str]:
        """Extract the publication types from an Article element."""
        return [t.text for t in article.find('PublicationTypeList').findall('PublicationType')]

    def _parse_pubmed_articles(self, batches: List[str]) -> List[Document]:
        """Parse batches of ET.Elements into a single list of Document objects

        Articles without a MedlineCitation, Article or abstract are skipped; long
        abstracts are split into several Documents (one per chunk).
        """
        data = []
        for ib, text in enumerate(batches):
            # defusedxml guards against XML bombs in the (external) Pubmed response
            pubmed_articles = DET.fromstring(text).findall('PubmedArticle')
            self.logger.debug(f'found #articles={len(pubmed_articles)} in batch {ib}')
            for ie, element in enumerate(pubmed_articles):
                self.logger.debug(f'parsing PubmedArticle {ie} of batch {ib}')
                # Extract MedlineCitation and its PMID from PubmedArticle
                citation = self._extract_medline_citation(element=element)
                if citation is None:
                    self.logger.debug(f'no citation found in PubmedArticle {ie} of batch {ib}')
                    continue
                pmid = self._extract_pmid(element=citation)

                # Extract the Article element from the PubmedArticle
                article = self._extract_article(element=citation)
                if article is None:
                    self.logger.debug(f'no article found in PubmedArticle {ie} of batch {ib}')
                    continue

                # Extract the relevant information from the Article element
                title = self._extract_title(article=article)
                text = self._extract_abstract(article=article)
                if text is None:
                    self.logger.debug(f'no abstract found in PubmedArticle {ie} of batch {ib}')
                    continue
                language = self._extract_language(article=article)
                publication_date = self._extract_date(article=article)

                # Split long texts into chunks
                texts = self.split_text_into_chunks(text=text)

                # Create the Document object(s)
                for txt in texts:
                    document = Document(
                        # id_ is derived from content so identical chunks deduplicate
                        # consistently (NOTE(review): presumably hashed downstream — verify in base.Document)
                        id_=f'{self._source_url} {title} {txt} {language} {publication_date}',
                        text=txt,
                        source=self._source,
                        title=title,
                        url=f'{self._source_url}{pmid}/',
                        source_url=self._source_url,
                        source_favicon_url=self._source_favicon_url,
                        language=language,
                        publication_date=publication_date,
                    )
                    data.append(document)
        self.logger.debug(f'parsed #docs={len(data)} from #batches={len(batches)}')
        return data


    async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
        """Query and retrieve all PubmedArticle Documents for the given search term.

        The retrieval is using two main functionalities of the Pubmed API:
        - ESearch: Identify the relevant documents and store them in the entrez history server
        - EFetch: Retrieve the relevant documents from the entrez history server

        Args:
            - args_: the arguments for the spock pipeline
            - queue_out: the output queue for the scraped data
        """
        term = args_.term
        max_docs_src = args_.max_docs_src
        self.logger.debug(f'starting scraping the source={self._source} with term={term}')

        # Search for relevant documents with a given term
        esearch = await self._pubmed_esearch(term=term)

        # Retrieve relevant documents in batches
        params = self._get_entrez_history_params(esearch)
        batches = await self._pubmed_efetch(params=params, max_docs_src=max_docs_src)

        # Parse documents from batches
        documents = self._parse_pubmed_articles(batches=batches)
        # Chunking may have multiplied the document count; enforce the cap again.
        documents = documents[:max_docs_src]

        # Add documents to the queue
        for i, doc in enumerate(documents):
            id_ = f'{self._source}_{i}'
            item = QueueItem(id_=id_, doc=doc)
            await queue_out.put(item)

        self.logger.info(f'retrieved #docs={len(documents)} in source={self._source} for term={term}')
282
+
283
+
284
@dataclass
class SearchResults:
    """Class for storing the search results from different databases.

    ``n_pages`` and ``items`` carry defaults so that sources without pagination
    (e.g. MHRA, which builds ``SearchResults(count=..., items=...)``) can omit
    ``n_pages``; previously that call raised a TypeError because the field was
    required.
    """
    # total number of hits reported by the source, or None if it could not be extracted
    count: int | None
    # number of result pages, or None for sources without pagination info
    n_pages: int | None = None
    # raw HTML tags, one per individual search result
    items: List[Tag] = field(default_factory=list)
290
+
291
+
292
class EMAScraper(Scraper):
    """Class for scraping data from the European Medicines Agency.

    The scraper uses the same API as the web interface of the EMA to search for relevant documents. From the list of
    results it creates a list of :class:`Document` objects by the following main steps:
    - Search the EMA database (filter for PDF documents only)
    - Extract the text from the PDF documents
    - Use regex to find texts where adverse drug reactions (or similar) are mentioned
    - Return the most recent :class:`Document` objects
    """

    _source = 'ema'
    _source_url = 'https://www.ema.europa.eu'
    _source_favicon_url = 'https://www.ema.europa.eu/themes/custom/ema_theme/favicon.ico'

    _pdf_search_template = (
        "https://www.ema.europa.eu/en/search?search_api_fulltext={term}"
        "&f%5B0%5D=ema_search_custom_entity_bundle%3Adocument"  # This part is added to only retrieve PDF documents
        "&f%5B1%5D=ema_search_entity_is_document%3ADocument"
    )
    # Browser-like headers; NOTE(review): presumably required because the EMA site
    # rejects default client user agents — confirm before removing.
    _headers = {
        'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
        'Host': 'www.ema.europa.eu',
    }

    _robots_txt_url = 'https://www.ema.europa.eu/robots.txt'

    def _extract_search_results_count(self, soup: BeautifulSoup) -> int | None:
        """Extract the number of search results."""
        span = soup.find("span", class_="source-summary-count")
        if span is None:
            self.logger.warning('no search results count found')
            return None
        # The count is rendered as e.g. "(42)" — strip the parentheses before parsing.
        return int(span.text.strip('()'))

    def _extract_number_of_pages(self, soup: BeautifulSoup) -> int | None:
        """Extract the number of pages from the search results.

        The "Last" pager link carries a zero-based page index, hence the +1.
        """
        nav = soup.find('nav', class_='pager')
        if nav is not None:
            a = nav.find('a', {'class': 'page-link', 'aria-label': 'Last'})
            if a and a.has_attr('href'):
                href = a['href']
                match = re.search(r'&page=(\d+)', href)
                if match:
                    return int(match.group(1)) + 1
        self.logger.warning('no pager found')
        return None

    @staticmethod
    def _extract_search_item_divs(soup: BeautifulSoup) -> List[Tag]:
        """Extract the list of div elements containing the different search results."""
        parent = soup.find('div', class_=['row', 'row-cols-1'])
        return parent.find_all('div', class_='col')

    async def _ema_document_search(self, term: str, max_docs_src: int) -> SearchResults:
        """Search the EMA database for PDF documents with a given term.

        Args:
            term: the search term.
            max_docs_src: stop paging once at least this many result items are collected.
        """

        # Get initial search results
        url = self._pdf_search_template.format(term=term)
        self.logger.debug(f'search ema database with url={url}')
        content = await self._aiohttp_get_html(url=url, headers=self._headers)
        soup = BeautifulSoup(content, 'html.parser')

        # Get the number of search results and number of pages
        count = self._extract_search_results_count(soup=soup)
        n_pages = self._extract_number_of_pages(soup=soup)

        # Extract the divs containing the search results
        items = []
        if count is not None and count > 0:
            # Extract items from page=0
            items_from_page = self._extract_search_item_divs(soup=soup)
            items.extend(items_from_page)

            # Extract items from page=1, 2, ...
            if n_pages is not None and n_pages > 1:
                base_url = url
                for i in range(1, n_pages):
                    # Bug fix: build each page URL from the original search URL.
                    # Re-assigning `url` here used to accumulate "&page=i" params
                    # across iterations (e.g. "...&page=1&page=2"), requesting the
                    # wrong pages from the second iteration onwards.
                    page_url = f'{base_url}&page={i}'
                    content = await self._aiohttp_get_html(url=page_url, headers=self._headers)
                    soup = BeautifulSoup(content, 'html.parser')

                    items_from_page = self._extract_search_item_divs(soup=soup)
                    items.extend(items_from_page)

                    if len(items) >= max_docs_src:
                        self.logger.debug(f'found #items={len(items)} in #pages={i+1}')
                        break

        # Check for extraction mismatch (also triggers when paging stopped early)
        if len(items) != count:
            self.logger.warning(f'mismatch #items={len(items)} and the total count={count}')

        self.logger.debug(f'extracted #items={len(items)} in #pages={n_pages}')
        return SearchResults(count=count, n_pages=n_pages, items=items)

    @staticmethod
    def _extract_title(tag: Tag) -> str | None:
        """Extract the title of the document."""
        title = tag.find('p', class_='file-title')
        return title.text if title is not None else None

    def _extract_url(self, tag: Tag) -> str | None:
        """Extract the links href to the relevant PDF document.

        Relative hrefs are resolved against the EMA base URL; non-PDF links are rejected.
        """
        link = tag.find('a', href=True)
        if link is None:
            self.logger.warning('no link found')
            return None

        href = link['href']
        url = f'{self._source_url}{href}' if href.startswith('/') else href
        if not url.endswith('.pdf'):
            self.logger.warning(f'url={url} does not point to a PDF document')
            return None
        return url

    async def _extract_text(self, url: str) -> str:
        """Extract the text from the PDF document."""
        async with aiohttp.ClientSession(headers=self._headers) as session:
            async with session.get(url) as response:
                response.raise_for_status()
                content = await response.read()  # Read the entire content

        # Create a BytesIO object from the content
        pdf_stream = BytesIO(content)

        # Open the PDF with PyMuPDF using the BytesIO object
        doc = pymupdf.open(stream=pdf_stream, filetype="pdf")

        # Extract text from all pages
        text = '\n'.join([page.get_text() for page in doc])

        # Close the document
        doc.close()
        return text

    @staticmethod
    def _extract_language(tag: Tag) -> str | None:
        """Extract the language of the document.

        The language meta line carries the language in parentheses, e.g. "... (English)".
        """
        lang_tag = tag.find('p', class_='language-meta')
        if lang_tag is None:
            return None
        text = lang_tag.text
        start = text.find('(')
        stop = text.find(')')
        if start != -1 and stop != -1:
            return text[start+1:stop].lower()
        return None

    @staticmethod
    def _extract_date(tag: Tag) -> datetime | None:
        """Extract the publication date of the document."""
        time_tag = tag.find('time')
        if time_tag and time_tag.has_attr('datetime'):
            return datetime.fromisoformat(time_tag['datetime'])
        return None

    async def _parse_items(self, items: List[Tag]) -> List[Document]:
        """From a list of divs containing the search results, extract the relevant information and parse it into a list
        of :class:`Document` objects.

        Items without a usable PDF url are skipped; long PDFs are split into several
        Documents (one per chunk).
        """
        data = []
        for i, tag in enumerate(items):
            url = self._extract_url(tag=tag)
            if url is None:
                self.logger.debug(f'no url found for item {i}')
                continue

            # Extract the relevant information from the document
            self.logger.debug(f'parsing document with url={url}')
            title = self._extract_title(tag=tag)
            text = await self._extract_text(url=url)
            language = self._extract_language(tag=tag)
            publication_date = self._extract_date(tag=tag)

            # Split long texts into chunks
            texts = self.split_text_into_chunks(text=text)

            # Create the Document object(s)
            for text in texts:
                document = Document(
                    id_=f'{self._source_url} {title} {text} {language} {publication_date}',
                    text=text,
                    source=self._source,
                    title=title,
                    url=url,
                    source_url=self._source_url,
                    source_favicon_url=self._source_favicon_url,
                    language=language,
                    publication_date=publication_date,
                )
                data.append(document)
        self.logger.debug(f'created #docs={len(data)}')
        return data

    async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
        """Query and retrieve all PRAC documents for the given search term.

        Args:
            - args_: the arguments for the spock pipeline
            - queue_out: the output queue for the scraped data
        """
        term = args_.term
        max_docs_src = args_.max_docs_src
        self.logger.debug(f'starting scraping the source={self._source} with term={term}')

        # Search for relevant documents with a given term
        search_results = await self._ema_document_search(term=term, max_docs_src=max_docs_src)
        n_items = len(search_results.items)
        if n_items > max_docs_src:
            self.logger.warning(f'from #items={n_items} only max_docs_src={max_docs_src} will be parsed')
        items = search_results.items[:max_docs_src]

        # Parse the documents
        data = await self._parse_items(items=items)
        n_data = len(data)
        if n_data > max_docs_src:
            self.logger.warning(f'the #items={n_items} were chunked into #documents={n_data} from where only max_docs_src={max_docs_src} will be added to the queue')
        data = data[:max_docs_src]

        # Add documents to the queue
        for i, doc in enumerate(data):
            id_ = f'{self._source}_{i}'
            item = QueueItem(id_=id_, doc=doc)
            await queue_out.put(item)

        self.logger.info(f'retrieved #docs={len(data)} in source={self._source} for term={term}')
518
+
519
+
520
class MHRAScraper(Scraper):
    """Class for scraping data from the Medicines and Healthcare products Regulatory Agency.

    The scraper uses the MHRA's **Drug Safety Update** search API for retrieving relevant documents. From the list of
    results it creates a list of :class:`Document` objects.
    """

    _source = 'mhra'
    # Landing page of the Drug Safety Update section.
    # Bug fix: was 'https://www.gov.uk/durg-safety-update' ('durg' -> 'drug'),
    # a broken link that was embedded in every emitted Document.
    _source_url = 'https://www.gov.uk/drug-safety-update'
    _source_favicon_url = 'https://www.gov.uk/favicon.ico'

    _search_template = 'https://www.gov.uk/drug-safety-update?keywords={term}'
    _source_base_url = 'https://www.gov.uk'
    _language = 'en'  # the Drug Safety Update site is published in English only

    def _extract_search_results_count(self, parent: Tag) -> int | None:
        """Extract the number of search results."""
        div = parent.find('div', class_='result-info__header')
        h2 = div.find('h2') if div else None

        if h2 is None:
            self.logger.warning('no search results count found')
            return None
        text = h2.get_text(strip=True)
        # The header reads e.g. "123 results" — take the first number.
        # NOTE(review): assumes the header always contains a digit; otherwise
        # re.search returns None and .group() raises AttributeError.
        count = int(re.search(r'\d+', text).group())
        return count

    @staticmethod
    def _extract_search_item_divs(parent: Tag) -> List[Tag]:
        """Extract the divs containing the search results."""
        return parent.find_all('li', class_='gem-c-document-list__item')

    async def _mhra_document_search(self, term: str) -> SearchResults:
        """Search the MHRA database for documents with a given term."""

        # Get search results and extract divs containing the search results
        url = self._search_template.format(term=term)
        self.logger.debug(f"search mhra's drug safety update database with url={url}")
        content = await self._aiohttp_get_html(url=url)
        soup = BeautifulSoup(content, 'html.parser')
        parent = soup.find('div', class_=['govuk-grid-column-two-thirds', 'js-live-search-results-block', 'filtered-results'])

        # Extract the number of search results
        count = self._extract_search_results_count(parent=parent)

        # Extract the divs containing the search results
        items = []
        if count is not None and count > 0:
            items = self._extract_search_item_divs(parent=parent)  # For a given search term, the site shows all the results without pagination.

        # Check for extraction mismatch
        if len(items) != count:
            self.logger.warning(f'mismatch #items={len(items)} and the total count={count}')

        self.logger.debug(f'found #items={len(items)}')
        # Bug fix: pass n_pages explicitly — SearchResults.n_pages has no default
        # value, so omitting it (as before) raised a TypeError at runtime.
        return SearchResults(count=count, n_pages=None, items=items)

    def _extract_url(self, link: Tag) -> str | None:
        """Extract the url to the document, resolving relative hrefs against gov.uk."""
        href = link['href']
        return f'{self._source_base_url}{href}' if href.startswith('/') else href

    async def _extract_text(self, url: str) -> str:
        """Extract the text from the document."""
        content = await self._aiohttp_get_html(url=url)
        soup = BeautifulSoup(content, 'html.parser')
        main = soup.find('main')
        text = main.get_text()

        # Clean the spaces and newlines: collapse >2 blank lines and strip leading indentation
        text = re.sub(r'(\n\s*){3,}', '\n\n', text)
        text = re.sub(r'^ +', '', text, flags=re.MULTILINE)
        return text

    @staticmethod
    def _extract_date(tag: Tag) -> datetime | None:
        """Extract the publication date of the document."""
        time_tag = tag.find('time')
        if time_tag and time_tag.has_attr('datetime'):
            return datetime.fromisoformat(time_tag['datetime'])
        return None

    async def _parse_items(self, items: List[Tag]) -> List[Document]:
        """From a list of divs containing the search results (items), extract the relevant information and parse it into a list
        of :class:`Document` objects.

        Items without a link are skipped; long pages are split into several Documents
        (one per chunk).
        """
        data = []
        for i, tag in enumerate(items):
            link = tag.find('a', href=True)
            if link is None:
                self.logger.warning(f'no link found for item {i}')
                continue
            url = self._extract_url(link=link)

            # Extract the relevant information from the document
            self.logger.debug(f'parsing item with url={url}')
            title = link.get_text(strip=True)
            text = await self._extract_text(url=url)
            publication_date = self._extract_date(tag=tag)

            # Split long texts into chunks
            texts = self.split_text_into_chunks(text=text)

            # Create the Document object(s)
            for text in texts:
                document = Document(
                    id_=f'{self._source_url} {title} {text} {publication_date}',
                    text=text,
                    source=self._source,
                    title=title,
                    url=url,
                    source_url=self._source_url,
                    source_favicon_url=self._source_favicon_url,
                    language=self._language,
                    publication_date=publication_date,
                )
                data.append(document)
        self.logger.debug(f'created #docs={len(data)}')
        return data

    async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
        """Query and retrieve all drug safety updates for the given search term.

        Args:
            - args_: the arguments for the spock pipeline
            - queue_out: the output queue for the scraped data
        """
        term = args_.term
        max_docs_src = args_.max_docs_src
        self.logger.debug(f'starting scraping the source={self._source} with term={term}')

        # Search for relevant documents with a given term
        search_results = await self._mhra_document_search(term=term)
        n_items = len(search_results.items)
        if n_items > max_docs_src:
            self.logger.warning(f'from #items={n_items} only max_docs_src={max_docs_src} will be parsed')
        items = search_results.items[:max_docs_src]

        # Parse the documents
        data = await self._parse_items(items=items)
        n_data = len(data)
        if n_data > max_docs_src:
            self.logger.warning(f'the #items={n_items} were chunked into #documents={n_data} from where only max_docs_src={max_docs_src} will be added to the queue')
        data = data[:max_docs_src]

        # Add documents to the queue
        for i, doc in enumerate(data):
            id_ = f'{self._source}_{i}'
            item = QueueItem(id_=id_, doc=doc)
            await queue_out.put(item)

        self.logger.info(f'retrieved #docs={len(data)} in source={self._source} for term={term}')
672
+
673
+
674
+ class FDAScraper(Scraper):
675
+ """Class for scraping data from the Food and Drug Administration.
676
+
677
+ The scraper uses the the same API as the web interface fo the FDA to search for relevant documents. By default the search applies the following filter:
678
+ - sorting by highest relevance
679
+ - filter for results from the Center of Drug Evaluation and Research
680
+ - filter for English language
681
+ - filter for Drugs
682
+ """
683
+
684
+ _source = 'fda'
685
+ _source_url = 'https://www.fda.gov'
686
+ _source_favicon_url = 'https://www.fda.gov/favicon.ico'
687
+
688
+ _search_template = (
689
+ 'https://www.fda.gov/search?s={term}'
690
+ '&items_per_page=10'
691
+ '&sort_bef_combine=rel_DESC' # Sort by relevance
692
+ '&f%5B0%5D=center%3A815' # Filter for the Center for Drug Evaluation and Research
693
+ '&f%5B1%5D=language%3A1404' # Filter for English language
694
+ '&f%5B2%5D=prod%3A2312' # Filter for the Drugs section
695
+ )
696
+ _language = 'en'
697
+
698
+
699
+ def _extract_search_results_count(self, soup: BeautifulSoup) -> int | None:
700
+ """Extract the number of search results from the search info section."""
701
+ parent = soup.find('div', class_='lcds-search-filters__info')
702
+ if parent is not None:
703
+ div = parent.find('div', class_='view-header')
704
+ match = re.search(r'of (\d+) entr[y|ies]', div.text)
705
+ if match:
706
+ return int(match.group(1))
707
+ self.logger.warning('no search info found')
708
+ return None
709
+
710
+ def _extract_number_of_pages(self, soup: BeautifulSoup) -> int | None:
711
+ """Extract the number of pages from the search results."""
712
+ nav = soup.find('nav', class_=['pager-nav', 'text-center'])
713
+ if nav is not None:
714
+ last_page = nav.find('li', class_=['pager__item', 'pager__item--last'])
715
+ a = last_page.find('a')
716
+ if a and a.has_attr('href'):
717
+ href = a['href']
718
+ match = re.search(r'&page=(\d+)', href)
719
+ if match:
720
+ return int(match.group(1)) + 1
721
+ self.logger.warning('no pager found')
722
+ return None
723
+
724
+ @staticmethod
725
+ def _extract_search_item_divs(soup: BeautifulSoup) -> List[Tag]:
726
+ """Extract the divs containing the search results."""
727
+ parent = soup.find('div', class_='view-content')
728
+ return parent.find_all('div', recursive=False)
729
+
730
async def _fda_document_search(self, term: str, max_docs_src: int) -> SearchResults:
    """Search the FDA database for documents matching *term*.

    Fetches the first results page, derives the total count and page count,
    then walks subsequent pages until `max_docs_src` items are collected or
    the pages are exhausted.

    Args:
        term: Search term to query for.
        max_docs_src: Upper bound on items to collect before stopping pagination.

    Returns:
        A SearchResults with the reported count, the number of pages, and the
        raw result divs.
    """
    # Get the first page of search results.
    base_url = self._search_template.format(term=term)
    self.logger.debug(f'search fda database with url={base_url}')
    content = await self._aiohttp_get_html(url=base_url)
    soup = BeautifulSoup(content, 'html.parser')

    # Total number of search results and number of result pages.
    count = self._extract_search_results_count(soup=soup)
    n_pages = self._extract_number_of_pages(soup=soup)

    # Extract the divs containing the search results.
    items = []
    if count is not None and count > 0:
        # Items from page=0 (already fetched above).
        items.extend(self._extract_search_item_divs(soup=soup))

        # Items from page=1, 2, ...
        if n_pages is not None and n_pages > 1:
            for i in range(1, n_pages):
                # FIX: build each page URL from the unmodified base URL. The
                # original did `url = f'{url}&page={i}'`, accumulating the
                # page parameter ('...&page=1&page=2...') on every iteration.
                url = f'{base_url}&page={i}'
                content = await self._aiohttp_get_html(url=url)
                soup = BeautifulSoup(content, 'html.parser')

                items.extend(self._extract_search_item_divs(soup=soup))

                if len(items) >= max_docs_src:
                    self.logger.debug(f'found #items={len(items)} in #pages={i+1}')
                    break

    # Check for extraction mismatch. NOTE: this also fires when pagination
    # stopped early because max_docs_src was reached.
    if len(items) != count:
        self.logger.warning(f'mismatch #items={len(items)} and the total count={count}')

    self.logger.debug(f'extracted #items={len(items)} in #pages={n_pages}')
    return SearchResults(count=count, n_pages=n_pages, items=items)
770
+
771
@staticmethod
def _extract_title(main: Tag) -> str | None:
    """Return the document title from the page header, or None if absent."""
    header = main.find('header', class_=['row', 'content-header'])
    if header is None:
        return None
    h1 = header.find('h1', class_=['content-title', 'text-center'])
    return h1.text if h1 is not None else None
780
+
781
@staticmethod
def _extract_text(main: Tag) -> str | None:
    """Return the plain text of the document's main body, or None if absent."""
    body = main.find('div', attrs={'class': 'col-md-8 col-md-push-2', 'role': 'main'})
    if body is None:
        return None
    return body.get_text()
786
+
787
@staticmethod
def _extract_date(main: Tag) -> datetime | None:
    """Extract the publication date of the document.

    Prefers the <time> element inside the description-list grid; falls back
    to the first <time> element anywhere under <main>.

    Returns:
        The parsed ISO datetime, or None if no usable <time datetime=...>
        element is found.
    """
    time_tag = None
    dl = main.find('dl', class_='lcds-description-list--grid')
    if dl is not None:
        dd = dl.find('dd', class_='cell-2_2')
        # FIX: guard against a missing <dd> cell; the original called
        # `dd.find('time')` unconditionally and would raise on None.
        if dd is not None:
            time_tag = dd.find('time')
    if time_tag is None:
        # Fall back to any <time> element in the document body.
        time_tag = main.find('time')
    if time_tag and time_tag.has_attr('datetime'):
        return datetime.fromisoformat(time_tag['datetime'])
    return None
799
+
800
async def _parse_item_page(self, url: str) -> List[Document]:
    """Fetch one FDA document page and parse it into one or more Documents.

    Long texts are split into chunks; one Document is created per chunk,
    all sharing the page's title, URL and publication date.

    Args:
        url: Absolute URL of the document page.

    Returns:
        The list of Document chunks (empty if the page has no <main> element).
    """
    content = await self._aiohttp_get_html(url=url)
    soup = BeautifulSoup(content, 'html.parser')
    main = soup.find('main')
    # FIX: pages without a <main> element previously crashed the extractors
    # (None was passed to `main.find(...)`); skip such pages instead.
    if main is None:
        self.logger.warning(f'no <main> element found for url={url}')
        return []

    # Extract the relevant information from the document.
    title = self._extract_title(main=main)
    text = self._extract_text(main=main)
    publication_date = self._extract_date(main=main)

    # Split long texts into chunks.
    texts = self.split_text_into_chunks(text=text)

    # Create one Document per chunk.
    data = []
    for text in texts:
        document = Document(
            # id_ is derived from source/title/chunk/date so chunks get
            # distinct, reproducible identifiers.
            id_=f'{self._source_url} {title} {text} {publication_date}',
            text=text,
            source=self._source,
            title=title,
            url=url,
            source_url=self._source_url,
            source_favicon_url=self._source_favicon_url,
            language=self._language,
            publication_date=publication_date,
        )
        data.append(document)
    return data
829
+
830
async def _parse_items(self, items: List[Tag]) -> List[Document]:
    """Turn search-result divs into :class:`Document` objects by following
    each item's link and parsing the target page.
    """
    documents: List[Document] = []
    for idx, div in enumerate(items):
        anchor = div.find('a', href=True)
        if anchor is None:
            self.logger.warning(f'no link found for item {idx}')
            continue
        url = self._source_url + anchor['href']

        # Fetch and parse the linked document page.
        self.logger.debug(f'parsing item with url={url}')
        documents.extend(await self._parse_item_page(url=url))
    self.logger.debug(f'created #docs={len(documents)}')
    return documents
848
+
849
async def apply(self, args_: Namespace, queue_out: asyncio.Queue) -> None:
    """Scrape this source for the configured term and enqueue the documents.

    Args:
        args_: Pipeline arguments; reads ``term`` and ``max_docs_src``.
        queue_out: Output queue receiving one QueueItem per document.
    """
    term = args_.term
    max_docs_src = args_.max_docs_src
    self.logger.debug(f'starting scraping the source={self._source} with term={term}')

    # Search for relevant documents with the given term.
    search_results = await self._fda_document_search(term=term, max_docs_src=max_docs_src)
    n_items = len(search_results.items)
    if n_items > max_docs_src:
        self.logger.warning(f'from #items={n_items} only max_docs_src={max_docs_src} will be parsed')
    # Slicing is a no-op when fewer items than the cap were found.
    items = search_results.items[:max_docs_src]

    # Parse the documents (chunking may yield more documents than items).
    data = await self._parse_items(items=items)
    n_data = len(data)
    if n_data > max_docs_src:
        self.logger.warning(f'the #items={n_items} were chunked into #documents={n_data} from where only max_docs_src={max_docs_src} will be added to the queue')
        data = data[:max_docs_src]

    # Hand the documents over to the pipeline.
    for idx, doc in enumerate(data):
        await queue_out.put(QueueItem(id_=f'{self._source}_{idx}', doc=doc))

    self.logger.info(f'retrieved #docs={len(data)} in source={self._source} for term={term}')
875
+
876
+
877
# One scraper class per entry in SCRAPING_SOURCES, in matching order.
_SCRAPERS = [PubmedScraper, EMAScraper, MHRAScraper, FDAScraper]
# FIX: validate the lengths *before* building the mapping — zip() silently
# truncates to the shorter sequence, so the original mapping could drop
# sources without the check ever explaining which. Also use `!=` instead of
# the `not ... == ...` idiom.
if len(_SCRAPERS) != len(SCRAPING_SOURCES):
    raise ValueError("number of scrapers and sources do not match")
_SOURCE_TO_SCRAPER = {src: scr for src, scr in zip(SCRAPING_SOURCES, _SCRAPERS)}
881
+
882
async def _scraping(args_: Namespace, queue_in: asyncio.Queue, queue_out: asyncio.Queue) -> None:
    """Pop a source (str) from the input queue, perform the scraping task with the given term,
    and put the results in the output queue until a ``None`` sentinel is received.

    Args:
        args_: the arguments for the spock pipeline (reads at least ``term``)
        queue_in: the input queue containing the sources to scrape; ``None`` stops this worker
        queue_out: the output queue for the scraped data
    """

    while True:
        # Get source from input queue.
        source = await queue_in.get()

        # Stopping condition: None is the shutdown sentinel.
        if source is None:
            queue_in.task_done()
            break

        # Look up the scraper for this source.
        scraper = _SOURCE_TO_SCRAPER.get(source)  # type: type[Scraper]
        if scraper is None:
            # FIX: an unknown source previously `break`ed out of the loop and
            # killed the whole worker; skip just this item instead so the
            # remaining queued sources are still processed.
            logger.error(f'unknown source={source}')
            queue_in.task_done()
            continue

        # Run the scraper; errors for one source must not kill the worker.
        try:
            await scraper().apply(args_=args_, queue_out=queue_out)
        except Exception as e:
            logger.error(f'error during scraping for source={source} and term={args_.term}: {e}')
        finally:
            # Exactly one task_done() per consumed item, success or failure.
            queue_in.task_done()
915
+
916
+
917
def create_tasks(args_: Namespace, queue_in: asyncio.Queue, queue_out: asyncio.Queue) -> List[asyncio.Task]:
    """Create the configured number of asyncio scraping worker tasks."""
    n_tasks = args_.n_scp_tasks
    logger.info(f'setting up {n_tasks} scraping task(s) for source(s)={args_.source}')
    return [
        asyncio.create_task(_scraping(args_=args_, queue_in=queue_in, queue_out=queue_out))
        for _ in range(n_tasks)
    ]