1-ARIjitS commited on
Commit
772bbc6
·
2 Parent(s): aa656dd ec6a815

Merge branch 'main' of https://huggingface.co/spaces/klinic-hackupc/klinic into main

Browse files
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
MATLAB/visualize_app.mlapp CHANGED
Binary files a/MATLAB/visualize_app.mlapp and b/MATLAB/visualize_app.mlapp differ
 
MATLAB/visualize_connectedNodes_continuous.m CHANGED
@@ -18,6 +18,12 @@ function visualize_connectedNodes_continuous()
18
  else
19
  connectionsMap(char_node) = char_connectedNode;
20
  end
 
 
 
 
 
 
21
  end
22
 
23
  % Loop for continuous interaction
 
18
  else
19
  connectionsMap(char_node) = char_connectedNode;
20
  end
21
+
22
+ if isKey(connectionsMap, char_connectedNode)
23
+ connectionsMap(char_connectedNode) = [connectionsMap(char_connectedNode), '|', char_node];
24
+ else
25
+ connectionsMap(char_connectedNode) = char_node;
26
+ end
27
  end
28
 
29
  % Loop for continuous interaction
app.py CHANGED
@@ -12,7 +12,8 @@ from utils import (
12
  get_similarities_among_diseases_uris,
13
  augment_the_set_of_diseaces,
14
  get_clinical_trials_related_to_diseases,
15
- get_clinical_records_by_ids
 
16
  )
17
  from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
18
  import json
@@ -36,11 +37,15 @@ CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}
36
  engine = create_engine(CONNECTION_STRING)
37
 
38
 
 
 
 
 
39
  with st.container(): # user input
40
  col1, col2 = st.columns((6, 1))
41
 
42
  with col1:
43
- description_input = st.text_area(label="Enter the disease description 👇", placeholder='A disease that causes memory loss and other cognitive impairments.')
44
 
45
  with col2:
46
  st.text('') # dummy to center vertically
@@ -60,6 +65,8 @@ with st.container():
60
  diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
61
  description_input, encoder
62
  )
 
 
63
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
64
  status.write("Getting the similarities among the diseases to filter out less promising ones...")
65
  diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
@@ -78,7 +85,7 @@ with st.container():
78
  json_of_clinical_trials = get_clinical_records_by_ids(
79
  [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
80
  )
81
- status.json(json_of_clinical_trials)
82
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
83
  status.write("Getting a summary of the clinical trials...")
84
  response = get_short_summary_out_of_json_files(json_of_clinical_trials)
@@ -91,13 +98,23 @@ with st.container():
91
  status.write(f'Response from LLM tagging: {response}')
92
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
93
  status.update(label="Done!", state="complete")
94
- time.sleep(1)
95
  show_graph = True
96
 
97
 
98
  # graph
99
  with st.container():
100
  if show_graph:
 
 
 
 
 
 
 
 
 
 
101
  # TODO actual graph
102
  graph_of_diseases = agraph(
103
  nodes=[
@@ -147,39 +164,13 @@ with st.container():
147
  # TODO replace mock data
148
  with open("mock_trial.json") as f:
149
  d = json.load(f)
150
- for i in range(0, 5):
151
  trials.append(d)
152
 
153
- for trial in trials:
154
- with st.expander(f"{trial['protocolSection']['identificationModule']['nctId']}"):
155
- official_title = trial["protocolSection"]["identificationModule"][
156
- "officialTitle"
157
- ]
158
- st.write(f"##### {official_title}")
159
-
160
- brief_summary = trial["protocolSection"]["descriptionModule"]["briefSummary"]
161
- st.write(brief_summary)
162
-
163
- status_module = {
164
- "Status": trial["protocolSection"]["statusModule"]["overallStatus"],
165
- "Status Date": trial["protocolSection"]["statusModule"][
166
- "statusVerifiedDate"
167
- ],
168
- }
169
- st.write("###### Status")
170
- st.table(status_module)
171
-
172
- design_module = {
173
- "Study Type": trial["protocolSection"]["designModule"]["studyType"],
174
- # "Phases": trial["protocolSection"]["designModule"]["phases"], # breaks formatting because it is an array
175
- "Allocation": trial["protocolSection"]["designModule"]["designInfo"][
176
- "allocation"
177
- ],
178
- "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"][
179
- "count"
180
- ],
181
- }
182
- st.write("###### Design")
183
- st.table(design_module)
184
-
185
- # TODO more modules?
 
12
  get_similarities_among_diseases_uris,
13
  augment_the_set_of_diseaces,
14
  get_clinical_trials_related_to_diseases,
15
+ get_clinical_records_by_ids,
16
+ render_trial_details
17
  )
18
  from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
19
  import json
 
37
  engine = create_engine(CONNECTION_STRING)
38
 
39
 
40
+ st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
41
+ st.title("Klìnic", help="AI-powered clinical trial search engine")
42
+ st.subheader("Find clinical trials in a scoped domain of biomedical research, guiding your research with AI-powered insights.")
43
+
44
  with st.container(): # user input
45
  col1, col2 = st.columns((6, 1))
46
 
47
  with col1:
48
+ description_input = st.text_area(label="Enter a disease description 👇", placeholder='A disease that causes memory loss and other cognitive impairments.')
49
 
50
  with col2:
51
  st.text('') # dummy to center vertically
 
65
  diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
66
  description_input, encoder
67
  )
68
+ status.info(f'Found {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
69
+ status.json(diseases_related_to_the_user_text, expanded=False)
70
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
71
  status.write("Getting the similarities among the diseases to filter out less promising ones...")
72
  diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
 
85
  json_of_clinical_trials = get_clinical_records_by_ids(
86
  [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
87
  )
88
+ status.json(json_of_clinical_trials, expanded=False)
89
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
90
  status.write("Getting a summary of the clinical trials...")
91
  response = get_short_summary_out_of_json_files(json_of_clinical_trials)
 
98
  status.write(f'Response from LLM tagging: {response}')
99
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
100
  status.update(label="Done!", state="complete")
101
+ status.balloons()
102
  show_graph = True
103
 
104
 
105
  # graph
106
  with st.container():
107
  if show_graph:
108
+ st.info(
109
+ """This is a graph of the relevant diseases that we found, based on the description that you entered. The diseases are connected by edges if they are similar to each other. The color of the edges represents the similarity of the diseases.
110
+
111
+ We use the embeddings of the diseases to determine the similarity between them. The embeddings are generated using a Representation Learning algorithm that learns the topological relations among the nodes in the graph, depending on how they are connected. We utilize the (PyKeen)[https://github.com/pykeen/pykeen] implementation of TransH to train an embedding model.
112
+
113
+ (TransH)[https://ojs.aaai.org/index.php/AAAI/article/view/8870] utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle many-to-many relations between entities. The model is trained on the triples of the graph, where the triples are the subject, relation, and object of the graph. The model learns the embeddings of the entities and the relations, such that the embeddings of the subject and object are close to each other when the relation is true.
114
+
115
+ Specifically, it optimizes the following cost function:
116
+ $$"""
117
+ )
118
  # TODO actual graph
119
  graph_of_diseases = agraph(
120
  nodes=[
 
164
  # TODO replace mock data
165
  with open("mock_trial.json") as f:
166
  d = json.load(f)
167
+ for i in range(0, 8):
168
  trials.append(d)
169
 
170
+ tab_titles = [f"{trial['protocolSection']['identificationModule']['nctId']}" for trial in trials]
171
+
172
+ tabs = st.tabs(tab_titles)
173
+
174
+ for i in range(0, len(tabs)):
175
+ with tabs[i]:
176
+ render_trial_details(trials[i])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
img_klinic.jpeg ADDED
utils.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  from sqlalchemy import create_engine, text
5
  import requests
6
  from sentence_transformers import SentenceTransformer
 
7
 
8
  username = "demo"
9
  password = "demo"
@@ -181,7 +182,7 @@ def get_clinical_trials_related_to_diseases(
181
  with engine.connect() as conn:
182
  with conn.begin():
183
  sql = f"""
184
- SELECT TOP 5 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
185
  FROM Test.ClinicalTrials d
186
  ORDER BY distance DESC
187
  """
@@ -190,6 +191,51 @@ def get_clinical_trials_related_to_diseases(
190
 
191
  return [{"nct_id": row[0], "distance": row[1]} for row in data]
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  if __name__ == "__main__":
195
  username = "demo"
 
4
  from sqlalchemy import create_engine, text
5
  import requests
6
  from sentence_transformers import SentenceTransformer
7
+ import streamlit as st
8
 
9
  username = "demo"
10
  password = "demo"
 
182
  with engine.connect() as conn:
183
  with conn.begin():
184
  sql = f"""
185
+ SELECT TOP 10 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
186
  FROM Test.ClinicalTrials d
187
  ORDER BY distance DESC
188
  """
 
191
 
192
  return [{"nct_id": row[0], "distance": row[1]} for row in data]
193
 
194
+ def to_capitalized_case(string: str) -> str:
195
+ string = string.replace("_", " ")
196
+ if string.isupper():
197
+ return string[0] + string[1:].lower()
198
+
199
+ def list_to_capitalized_case(strings: List[str]) -> str:
200
+ strings = [to_capitalized_case(s) for s in strings]
201
+ return ", ".join(strings)
202
+
203
+ def render_trial_details(trial: dict) -> None:
204
+ # TODO: handle key errors for all cases (→ do not render)
205
+
206
+ official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
207
+ st.write(f"##### {official_title}")
208
+
209
+ brief_summary = trial["protocolSection"]["descriptionModule"]["briefSummary"]
210
+ st.write(brief_summary)
211
+
212
+ status_module = {
213
+ "Status": to_capitalized_case(trial["protocolSection"]["statusModule"]["overallStatus"]),
214
+ "Status Date": trial["protocolSection"]["statusModule"]["statusVerifiedDate"],
215
+ "Has Results": trial["hasResults"]
216
+ }
217
+ st.write("###### Status")
218
+ st.table(status_module)
219
+
220
+ design_module = {
221
+ "Study Type": to_capitalized_case(trial["protocolSection"]["designModule"]["studyType"]),
222
+ "Phases": list_to_capitalized_case(trial["protocolSection"]["designModule"]["phases"]),
223
+ "Allocation": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["allocation"]),
224
+ "Primary Purpose": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]),
225
+ "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"]["count"],
226
+ "Masking": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["masking"]),
227
+ "Who Masked": list_to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["whoMasked"])
228
+ }
229
+ st.write("###### Design")
230
+ st.table(design_module)
231
+
232
+ interventions_module = {}
233
+ for intervention in trial["protocolSection"]["armsInterventionsModule"]["interventions"]:
234
+ name = intervention["name"]
235
+ desc = intervention["description"]
236
+ interventions_module[name] = desc
237
+ st.write("###### Interventions")
238
+ st.table(interventions_module)
239
 
240
  if __name__ == "__main__":
241
  username = "demo"