Spaces:
Sleeping
Sleeping
Merge branch 'main' of https://huggingface.co/spaces/klinic-hackupc/klinic into main
Browse files- .DS_Store +0 -0
- MATLAB/visualize_app.mlapp +0 -0
- MATLAB/visualize_connectedNodes_continuous.m +6 -0
- app.py +29 -38
- img_klinic.jpeg +0 -0
- utils.py +47 -1
.DS_Store
CHANGED
|
Binary files a/.DS_Store and b/.DS_Store differ
|
|
|
MATLAB/visualize_app.mlapp
CHANGED
|
Binary files a/MATLAB/visualize_app.mlapp and b/MATLAB/visualize_app.mlapp differ
|
|
|
MATLAB/visualize_connectedNodes_continuous.m
CHANGED
|
@@ -18,6 +18,12 @@ function visualize_connectedNodes_continuous()
|
|
| 18 |
else
|
| 19 |
connectionsMap(char_node) = char_connectedNode;
|
| 20 |
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
end
|
| 22 |
|
| 23 |
% Loop for continuous interaction
|
|
|
|
| 18 |
else
|
| 19 |
connectionsMap(char_node) = char_connectedNode;
|
| 20 |
end
|
| 21 |
+
|
| 22 |
+
if isKey(connectionsMap, char_connectedNode)
|
| 23 |
+
connectionsMap(char_connectedNode) = [connectionsMap(char_connectedNode), '|', char_node];
|
| 24 |
+
else
|
| 25 |
+
connectionsMap(char_connectedNode) = char_node;
|
| 26 |
+
end
|
| 27 |
end
|
| 28 |
|
| 29 |
% Loop for continuous interaction
|
app.py
CHANGED
|
@@ -12,7 +12,8 @@ from utils import (
|
|
| 12 |
get_similarities_among_diseases_uris,
|
| 13 |
augment_the_set_of_diseaces,
|
| 14 |
get_clinical_trials_related_to_diseases,
|
| 15 |
-
get_clinical_records_by_ids
|
|
|
|
| 16 |
)
|
| 17 |
from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
|
| 18 |
import json
|
|
@@ -36,11 +37,15 @@ CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}
|
|
| 36 |
engine = create_engine(CONNECTION_STRING)
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
with st.container(): # user input
|
| 40 |
col1, col2 = st.columns((6, 1))
|
| 41 |
|
| 42 |
with col1:
|
| 43 |
-
description_input = st.text_area(label="Enter
|
| 44 |
|
| 45 |
with col2:
|
| 46 |
st.text('') # dummy to center vertically
|
|
@@ -60,6 +65,8 @@ with st.container():
|
|
| 60 |
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
|
| 61 |
description_input, encoder
|
| 62 |
)
|
|
|
|
|
|
|
| 63 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
| 64 |
status.write("Getting the similarities among the diseases to filter out less promising ones...")
|
| 65 |
diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
|
|
@@ -78,7 +85,7 @@ with st.container():
|
|
| 78 |
json_of_clinical_trials = get_clinical_records_by_ids(
|
| 79 |
[trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
|
| 80 |
)
|
| 81 |
-
status.json(json_of_clinical_trials)
|
| 82 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
| 83 |
status.write("Getting a summary of the clinical trials...")
|
| 84 |
response = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
|
@@ -91,13 +98,23 @@ with st.container():
|
|
| 91 |
status.write(f'Response from LLM tagging: {response}')
|
| 92 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
| 93 |
status.update(label="Done!", state="complete")
|
| 94 |
-
|
| 95 |
show_graph = True
|
| 96 |
|
| 97 |
|
| 98 |
# graph
|
| 99 |
with st.container():
|
| 100 |
if show_graph:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
# TODO actual graph
|
| 102 |
graph_of_diseases = agraph(
|
| 103 |
nodes=[
|
|
@@ -147,39 +164,13 @@ with st.container():
|
|
| 147 |
# TODO replace mock data
|
| 148 |
with open("mock_trial.json") as f:
|
| 149 |
d = json.load(f)
|
| 150 |
-
for i in range(0,
|
| 151 |
trials.append(d)
|
| 152 |
|
| 153 |
-
for trial in trials
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
brief_summary = trial["protocolSection"]["descriptionModule"]["briefSummary"]
|
| 161 |
-
st.write(brief_summary)
|
| 162 |
-
|
| 163 |
-
status_module = {
|
| 164 |
-
"Status": trial["protocolSection"]["statusModule"]["overallStatus"],
|
| 165 |
-
"Status Date": trial["protocolSection"]["statusModule"][
|
| 166 |
-
"statusVerifiedDate"
|
| 167 |
-
],
|
| 168 |
-
}
|
| 169 |
-
st.write("###### Status")
|
| 170 |
-
st.table(status_module)
|
| 171 |
-
|
| 172 |
-
design_module = {
|
| 173 |
-
"Study Type": trial["protocolSection"]["designModule"]["studyType"],
|
| 174 |
-
# "Phases": trial["protocolSection"]["designModule"]["phases"], # breaks formatting because it is an array
|
| 175 |
-
"Allocation": trial["protocolSection"]["designModule"]["designInfo"][
|
| 176 |
-
"allocation"
|
| 177 |
-
],
|
| 178 |
-
"Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"][
|
| 179 |
-
"count"
|
| 180 |
-
],
|
| 181 |
-
}
|
| 182 |
-
st.write("###### Design")
|
| 183 |
-
st.table(design_module)
|
| 184 |
-
|
| 185 |
-
# TODO more modules?
|
|
|
|
| 12 |
get_similarities_among_diseases_uris,
|
| 13 |
augment_the_set_of_diseaces,
|
| 14 |
get_clinical_trials_related_to_diseases,
|
| 15 |
+
get_clinical_records_by_ids,
|
| 16 |
+
render_trial_details
|
| 17 |
)
|
| 18 |
from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
|
| 19 |
import json
|
|
|
|
| 37 |
engine = create_engine(CONNECTION_STRING)
|
| 38 |
|
| 39 |
|
| 40 |
+
st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
|
| 41 |
+
st.title("Klìnic", help="AI-powered clinical trial search engine")
|
| 42 |
+
st.subheader("Find clinical trials in a scoped domain of biomedical research, guiding your research with AI-powered insights.")
|
| 43 |
+
|
| 44 |
with st.container(): # user input
|
| 45 |
col1, col2 = st.columns((6, 1))
|
| 46 |
|
| 47 |
with col1:
|
| 48 |
+
description_input = st.text_area(label="Enter a disease description 👇", placeholder='A disease that causes memory loss and other cognitive impairments.')
|
| 49 |
|
| 50 |
with col2:
|
| 51 |
st.text('') # dummy to center vertically
|
|
|
|
| 65 |
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
|
| 66 |
description_input, encoder
|
| 67 |
)
|
| 68 |
+
status.info(f'Found {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
|
| 69 |
+
status.json(diseases_related_to_the_user_text, expanded=False)
|
| 70 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
| 71 |
status.write("Getting the similarities among the diseases to filter out less promising ones...")
|
| 72 |
diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
|
|
|
|
| 85 |
json_of_clinical_trials = get_clinical_records_by_ids(
|
| 86 |
[trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
|
| 87 |
)
|
| 88 |
+
status.json(json_of_clinical_trials, expanded=False)
|
| 89 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
| 90 |
status.write("Getting a summary of the clinical trials...")
|
| 91 |
response = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
|
|
|
| 98 |
status.write(f'Response from LLM tagging: {response}')
|
| 99 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
| 100 |
status.update(label="Done!", state="complete")
|
| 101 |
+
status.balloons()
|
| 102 |
show_graph = True
|
| 103 |
|
| 104 |
|
| 105 |
# graph
|
| 106 |
with st.container():
|
| 107 |
if show_graph:
|
| 108 |
+
st.info(
|
| 109 |
+
"""This is a graph of the relevant diseases that we found, based on the description that you entered. The diseases are connected by edges if they are similar to each other. The color of the edges represents the similarity of the diseases.
|
| 110 |
+
|
| 111 |
+
We use the embeddings of the diseases to determine the similarity between them. The embeddings are generated using a Representation Learning algorithm that learns the topological relations among the nodes in the graph, depending on how they are connected. We utilize the (PyKeen)[https://github.com/pykeen/pykeen] implementation of TransH to train an embedding model.
|
| 112 |
+
|
| 113 |
+
(TransH)[https://ojs.aaai.org/index.php/AAAI/article/view/8870] utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle many-to-many relations between entities. The model is trained on the triples of the graph, where the triples are the subject, relation, and object of the graph. The model learns the embeddings of the entities and the relations, such that the embeddings of the subject and object are close to each other when the relation is true.
|
| 114 |
+
|
| 115 |
+
Specifically, it optimizes the following cost function:
|
| 116 |
+
$$"""
|
| 117 |
+
)
|
| 118 |
# TODO actual graph
|
| 119 |
graph_of_diseases = agraph(
|
| 120 |
nodes=[
|
|
|
|
| 164 |
# TODO replace mock data
|
| 165 |
with open("mock_trial.json") as f:
|
| 166 |
d = json.load(f)
|
| 167 |
+
for i in range(0, 8):
|
| 168 |
trials.append(d)
|
| 169 |
|
| 170 |
+
tab_titles = [f"{trial['protocolSection']['identificationModule']['nctId']}" for trial in trials]
|
| 171 |
+
|
| 172 |
+
tabs = st.tabs(tab_titles)
|
| 173 |
+
|
| 174 |
+
for i in range(0, len(tabs)):
|
| 175 |
+
with tabs[i]:
|
| 176 |
+
render_trial_details(trials[i])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
img_klinic.jpeg
ADDED
|
utils.py
CHANGED
|
@@ -4,6 +4,7 @@ import os
|
|
| 4 |
from sqlalchemy import create_engine, text
|
| 5 |
import requests
|
| 6 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 7 |
|
| 8 |
username = "demo"
|
| 9 |
password = "demo"
|
|
@@ -181,7 +182,7 @@ def get_clinical_trials_related_to_diseases(
|
|
| 181 |
with engine.connect() as conn:
|
| 182 |
with conn.begin():
|
| 183 |
sql = f"""
|
| 184 |
-
SELECT TOP
|
| 185 |
FROM Test.ClinicalTrials d
|
| 186 |
ORDER BY distance DESC
|
| 187 |
"""
|
|
@@ -190,6 +191,51 @@ def get_clinical_trials_related_to_diseases(
|
|
| 190 |
|
| 191 |
return [{"nct_id": row[0], "distance": row[1]} for row in data]
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
if __name__ == "__main__":
|
| 195 |
username = "demo"
|
|
|
|
| 4 |
from sqlalchemy import create_engine, text
|
| 5 |
import requests
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
+
import streamlit as st
|
| 8 |
|
| 9 |
username = "demo"
|
| 10 |
password = "demo"
|
|
|
|
| 182 |
with engine.connect() as conn:
|
| 183 |
with conn.begin():
|
| 184 |
sql = f"""
|
| 185 |
+
SELECT TOP 10 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
|
| 186 |
FROM Test.ClinicalTrials d
|
| 187 |
ORDER BY distance DESC
|
| 188 |
"""
|
|
|
|
| 191 |
|
| 192 |
return [{"nct_id": row[0], "distance": row[1]} for row in data]
|
| 193 |
|
| 194 |
+
def to_capitalized_case(string: str) -> str:
|
| 195 |
+
string = string.replace("_", " ")
|
| 196 |
+
if string.isupper():
|
| 197 |
+
return string[0] + string[1:].lower()
|
| 198 |
+
|
| 199 |
+
def list_to_capitalized_case(strings: List[str]) -> str:
|
| 200 |
+
strings = [to_capitalized_case(s) for s in strings]
|
| 201 |
+
return ", ".join(strings)
|
| 202 |
+
|
| 203 |
+
def render_trial_details(trial: dict) -> None:
|
| 204 |
+
# TODO: handle key errors for all cases (→ do not render)
|
| 205 |
+
|
| 206 |
+
official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
|
| 207 |
+
st.write(f"##### {official_title}")
|
| 208 |
+
|
| 209 |
+
brief_summary = trial["protocolSection"]["descriptionModule"]["briefSummary"]
|
| 210 |
+
st.write(brief_summary)
|
| 211 |
+
|
| 212 |
+
status_module = {
|
| 213 |
+
"Status": to_capitalized_case(trial["protocolSection"]["statusModule"]["overallStatus"]),
|
| 214 |
+
"Status Date": trial["protocolSection"]["statusModule"]["statusVerifiedDate"],
|
| 215 |
+
"Has Results": trial["hasResults"]
|
| 216 |
+
}
|
| 217 |
+
st.write("###### Status")
|
| 218 |
+
st.table(status_module)
|
| 219 |
+
|
| 220 |
+
design_module = {
|
| 221 |
+
"Study Type": to_capitalized_case(trial["protocolSection"]["designModule"]["studyType"]),
|
| 222 |
+
"Phases": list_to_capitalized_case(trial["protocolSection"]["designModule"]["phases"]),
|
| 223 |
+
"Allocation": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["allocation"]),
|
| 224 |
+
"Primary Purpose": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]),
|
| 225 |
+
"Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"]["count"],
|
| 226 |
+
"Masking": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["masking"]),
|
| 227 |
+
"Who Masked": list_to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["whoMasked"])
|
| 228 |
+
}
|
| 229 |
+
st.write("###### Design")
|
| 230 |
+
st.table(design_module)
|
| 231 |
+
|
| 232 |
+
interventions_module = {}
|
| 233 |
+
for intervention in trial["protocolSection"]["armsInterventionsModule"]["interventions"]:
|
| 234 |
+
name = intervention["name"]
|
| 235 |
+
desc = intervention["description"]
|
| 236 |
+
interventions_module[name] = desc
|
| 237 |
+
st.write("###### Interventions")
|
| 238 |
+
st.table(interventions_module)
|
| 239 |
|
| 240 |
if __name__ == "__main__":
|
| 241 |
username = "demo"
|