cloud-sean commited on
Commit
fd3f788
·
1 Parent(s): 385ee99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -121
app.py CHANGED
@@ -3,58 +3,25 @@ import streamlit as st
3
  import os
4
  import json
5
  import time
6
- from annotated_text import annotated_text
7
-
8
-
9
-
10
- dax_input = st.text_area("DAX input")
11
-
12
-
13
- def generate_annotated_array(text, json_data):
14
- """
15
- Generate an array of strings based on annotations from the JSON data.
16
-
17
- Parameters:
18
- - text (str): The input text to be annotated.
19
- - json_data (dict): The JSON data containing annotations.
20
-
21
- Returns:
22
- - list: An array of strings with separate elements for each annotation.
23
- """
24
- entities = json_data['documents'][0]['entities']
25
- output = []
26
- index = 0
27
- buffer_text = ""
28
-
29
- while index < len(text):
30
- # Find the entity that matches the current position
31
- entity = next((e for e in entities if e['offset'] == index), None)
32
-
33
- if entity:
34
- # If there's buffer_text, add it to the output
35
- if buffer_text:
36
- output.append(buffer_text)
37
- buffer_text = ""
38
-
39
- # Add the annotated entity to the output
40
- output.append("[" + entity['text'] + ":" + entity['category'] + "]")
41
- index += entity['length']
42
- else:
43
- # If no entity is found, add the character to buffer_text
44
- buffer_text += text[index]
45
- index += 1
46
-
47
- # Add any remaining buffer_text to the output
48
- if buffer_text:
49
- output.append(buffer_text)
50
-
51
- return output
52
-
53
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
-
56
- import requests
57
- import time
 
 
58
 
59
  def analyze_healthcare_text(text):
60
  # Endpoint, headers and subscription key
@@ -80,8 +47,6 @@ def analyze_healthcare_text(text):
80
 
81
  # Making the initial POST request
82
  response = requests.post(f"{base_url}?api-version=2022-10-01-preview", headers=headers, json=data)
83
-
84
- time.sleep(10)
85
 
86
  # Get the operation-location from the response header
87
  operation_location = response.headers.get('operation-location')
@@ -89,103 +54,252 @@ def analyze_healthcare_text(text):
89
  # Extract JOB-ID from the operation-location
90
  job_id = operation_location.split('/')[-1].split('?')[0]
91
 
 
92
  # Make a subsequent GET request to retrieve the results using the JOB-ID
93
- result_response = requests.get(f"{base_url}/{job_id}?api-version=2022-10-01-preview", headers=headers)
 
 
 
 
 
 
94
 
95
  # Return the JSON response from the GET request
96
  result = result_response.json()
97
 
98
  return result
99
 
100
-
101
-
102
- def convert_to_annotated_text(input_list):
103
- """
104
- Convert a list with annotated content into a nested list suitable for annotated_text format.
 
 
 
 
 
 
 
 
105
 
106
- Args:
107
- - input_list (list): The list with content and annotations in format '[text:annotation]'.
 
 
108
 
109
- Returns:
110
- - list: A nested list in the annotated_text format.
111
- """
112
- annotated_list = []
113
- temp_group = []
114
-
115
- for item in input_list:
116
- # Check if the item is an annotation
117
- if item.startswith('[') and item.endswith(']'):
118
- content = item[1:-1].split(':')
119
- temp_group.append((content[0], content[1]))
120
- else:
121
- if temp_group: # if there are items in the temporary group
122
- annotated_list.append(temp_group)
123
- temp_group = []
124
- annotated_list.append(item)
125
 
126
- # Add any remaining items in the temporary group to the final list
127
- if temp_group:
128
- annotated_list.append(temp_group)
129
 
130
- return annotated_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- if st.button("Analyze"):
134
- text = dax_input
135
- json_analysis = analyze_healthcare_text(text)
136
- json_analysis = json_analysis["tasks"]["items"][0]["results"]
137
- # save json analysis as a file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
 
 
 
 
 
 
139
 
140
- new_text = generate_annotated_array(text, json_analysis)
141
- new_text = convert_to_annotated_text(new_text)
142
- annotated_text(new_text)
143
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
 
 
145
 
 
 
146
 
147
- st.title("Nuance DAX Copilot")
 
148
 
149
- os.environ["OPENAI_API_BASE"] = openai.api_type = "azure"
150
- os.environ["OPENAI_API_BASE"] = openai.api_base = "https://eastus-openai-sean.openai.azure.com/"
151
- os.environ["OPENAI_API_VERSION"] = openai.api_version = "2023-03-15-preview"
152
 
153
- openai.api_key = os.environ["OPENAI_API_KEY"]
154
- openai.api_version = os.environ["OPENAI_API_VERSION"]
155
- openai.api_base = os.environ["OPENAI_API_BASE"]
156
- os.environ["OPENAI_API_VERSION"] = openai.api_version = "2023-03-15-preview"
157
 
 
158
 
159
- if "messages" not in st.session_state:
160
- st.session_state.messages = [{"role":"system","content":"You are an AI assistant that ansswers questions about patient encounters. You are not a doctor and should not diagnose or treat patients. However, you can suggest common practices and help doctors with their questions that will help them make better decisions. \ Use only the information below: \n Patient Note / Encounter Summary: \n"}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
 
 
 
162
 
163
- for message in st.session_state.messages:
164
- with st.chat_message(message["role"]):
165
- st.markdown(message["content"])
166
 
 
167
 
 
168
 
169
 
 
 
 
 
 
170
 
 
 
 
171
 
172
- if prompt := st.chat_input("Nuance DAX Copilot?"):
173
- st.session_state.messages.append({"role": "user", "content": prompt + dax_input})
174
- with st.chat_message("user"):
175
- st.markdown(prompt)
176
 
177
- with st.chat_message("assistant"):
178
- message_placeholder = st.empty()
179
- full_response = ""
180
- for response in openai.ChatCompletion.create(
181
- messages=[
182
- {"role": m["role"], "content": m["content"]}
183
- for m in st.session_state.messages
184
- ],
185
- stream=True,
186
- engine="gpt-4",
187
- ):
188
- full_response += response.choices[0].delta.get("content", "")
189
- message_placeholder.markdown(full_response + "▌")
190
- message_placeholder.markdown(full_response)
191
- st.session_state.messages.append({"role": "assistant", "content": full_response})
 
3
  import os
4
  import json
5
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import requests
7
+ from html import escape
8
+ from random import choice
9
+ import plotly.graph_objects as go
10
+ import networkx as nx
11
+ import os
12
+ import openai
13
+ from langchain.agents import create_json_agent, AgentExecutor
14
+ from langchain.agents.agent_toolkits import JsonToolkit
15
+ from langchain.chains import LLMChain
16
+ from langchain.chat_models import AzureChatOpenAI
17
+ from langchain.requests import TextRequestsWrapper
18
+ from langchain.tools.json.tool import JsonSpec
19
 
20
+ openai.api_key = os.environ["OPENAI_API_KEY"]
21
+ openai.api_version = os.environ["OPENAI_API_VERSION"]
22
+ openai.api_base = os.environ["OPENAI_API_BASE"]
23
+ os.environ["OPENAI_API_VERSION"] = openai.api_version = "2023-03-15-preview"
24
+ st.set_page_config(layout="wide")
25
 
26
  def analyze_healthcare_text(text):
27
  # Endpoint, headers and subscription key
 
47
 
48
  # Making the initial POST request
49
  response = requests.post(f"{base_url}?api-version=2022-10-01-preview", headers=headers, json=data)
 
 
50
 
51
  # Get the operation-location from the response header
52
  operation_location = response.headers.get('operation-location')
 
54
  # Extract JOB-ID from the operation-location
55
  job_id = operation_location.split('/')[-1].split('?')[0]
56
 
57
+
58
  # Make a subsequent GET request to retrieve the results using the JOB-ID
59
+ success = False
60
+ while not success:
61
+ result_response = requests.get(f"{base_url}/{job_id}?api-version=2022-10-01-preview", headers=headers)
62
+ if result_response.json()['status'] == 'succeeded':
63
+ success = True
64
+ else:
65
+ time.sleep(1)
66
 
67
  # Return the JSON response from the GET request
68
  result = result_response.json()
69
 
70
  return result
71
 
72
+ def annotate_text_with_entities(original_text, entities_data):
73
+ # Color palette for different categories
74
+ PALETTE = [
75
+ "#ff4b4b",
76
+ "#ffa421",
77
+ "#ffe312",
78
+ "#21c354",
79
+ "#00d4b1",
80
+ "#00c0f2",
81
+ "#1c83e1",
82
+ "#803df5",
83
+ "#808495",
84
+ ]
85
 
86
+ # Opacities
87
+ OPACITIES = [
88
+ "33", "66",
89
+ ]
90
 
91
+ json_data = entities_data
92
+
93
+ # Extract entities from the JSON data
94
+ entities = json_data['documents'][0]['entities']
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # Get unique categories from entities
97
+ unique_categories = list(set(entity['category'] for entity in entities))
 
98
 
99
+ # Create a mapping of categories to colors
100
+ category_to_color = {category: PALETTE[i % len(PALETTE)] for i, category in enumerate(unique_categories)}
101
+
102
+ # If we run out of colors in the palette, we will assign random colors to the remaining categories
103
+ if len(unique_categories) > len(PALETTE):
104
+ additional_colors = ['#'+''.join([choice('0123456789ABCDEF') for _ in range(6)]) for _ in range(len(unique_categories) - len(PALETTE))]
105
+ for i, category in enumerate(unique_categories[len(PALETTE):]):
106
+ category_to_color[category] = additional_colors[i]
107
+
108
+ def create_entity_html(entity, entity_id):
109
+ # Get the color for the entity category
110
+ color = category_to_color[entity["category"]] + OPACITIES[entity["offset"] % len(OPACITIES)]
111
+
112
+ entity_html = f'<span id="entity-{entity_id}"><span style="display: inline-flex; flex-direction: row; align-items: center; background: {color}; border-radius: 0.5rem; padding: 0.25rem 0.5rem; overflow: hidden; line-height: 1;">{entity_id}. {escape(entity["text"])}'
113
+
114
+ # If there are links, create a dropdown menu with the links
115
+ if entity.get("links"):
116
+ options = "".join(f'<option value="{link["id"]}">{link["dataSource"]} Code {link["id"]}</option>' for link in entity["links"])
117
+ dropdown_html = f'''
118
+ <span style="border-left: 1px solid; opacity: 0.1; margin-left: 0.5rem; align-self: stretch;"></span>
119
+ <span style="margin-left: 0.5rem; display: flex; flex-direction: column; align-items: flex-start;">
120
+ <select style="font-size: 0.75rem; opacity: 0.5;">
121
+ {options}
122
+ </select>
123
+ <label style="font-size: 0.6rem; margin-top: 0.25rem;">{entity["category"]}</label>
124
+ </span>
125
+ '''
126
+ entity_html += dropdown_html
127
+ else:
128
+ # If there are no links, just display the category label
129
+ entity_html += f'<span style="border-left: 1px solid; opacity: 0.1; margin-left: 0.5rem; align-self: stretch;"></span><span style="margin-left: 0.5rem; font-size: 0.75rem; opacity: 0.5;">{entity["category"]}</span>'
130
 
131
+ # Close the main span element
132
+ entity_html += '</span></span>'
133
+
134
+ return entity_html
135
+
136
+ # Create HTML representation for each entity
137
+ entity_htmls = [create_entity_html(entity, i) for i, entity in enumerate(entities)]
138
+
139
+ # Replace entities in the original text with their HTML representations
140
+ # We iterate from the end to avoid changing the offsets of the yet-to-be-replaced entities
141
+ for entity, entity_html in sorted(zip(entities, entity_htmls), key=lambda x: x[0]['offset'], reverse=True):
142
+ start = entity['offset']
143
+ end = start + entity['length']
144
+ original_text = original_text[:start] + entity_html + original_text[end:]
145
+
146
+ # Create a color key section
147
+ color_key_section = "<br><br><strong>Color Key:</strong><br>"
148
+ for category, color in category_to_color.items():
149
+ color_key_section += f'<span style="display: inline-block; background: {color}; width: 1em; height: 1em; margin-right: 0.5em; vertical-align: middle;"></span>{category}<br>'
150
+
151
+ original_text += color_key_section
152
+
153
+ return original_text, category_to_color
154
 
155
+ def create_interactive_graph_from_json(json_data, category_to_color):
156
+ # Load the JSON data
157
+ entities = json_data['documents'][0]['entities']
158
+ relations = json_data['documents'][0].get('relations', [])
159
+
160
+ # Create a new directed graph
161
+ graph = nx.DiGraph()
162
+
163
+ # Add nodes to the graph
164
+ for i, entity in enumerate(entities):
165
+ graph.add_node(i, label=entity['text'], category=entity['category'])
166
+
167
+ # Add edges to the graph
168
+ for relation in relations:
169
+ source_index = int(relation['entities'][0]['ref'].split('/')[-1])
170
+ target_index = int(relation['entities'][1]['ref'].split('/')[-1])
171
+ graph.add_edge(source_index, target_index, label=relation['relationType'])
172
+
173
+ # Get positions of the nodes using spring layout
174
+ pos = nx.spring_layout(graph)
175
+
176
+ # Get node positions
177
+ x_nodes = [pos[i][0] for i in graph.nodes]
178
+ y_nodes = [pos[i][1] for i in graph.nodes]
179
+
180
+ # Get the colors for each node based on its category
181
+ node_colors = [category_to_color[graph.nodes[i]['category']] for i in graph.nodes]
182
 
183
+ # Get edge positions
184
+ x_edges = []
185
+ y_edges = []
186
+ for edge in graph.edges:
187
+ x_edges += [pos[edge[0]][0], pos[edge[1]][0], None]
188
+ y_edges += [pos[edge[0]][1], pos[edge[1]][1], None]
189
 
190
+ # Create edge traces
191
+ edge_trace = go.Scatter(x=x_edges, y=y_edges, line=dict(width=0.5, color='#888'), hoverinfo='none', mode='lines')
192
+
193
+ # Create node traces with all nodes
194
+ node_trace_all = go.Scatter(x=x_nodes, y=y_nodes, text=[graph.nodes[i]['label'] for i in graph.nodes],
195
+ mode='markers+text', hoverinfo='text', marker=dict(color=node_colors, size=10))
196
+
197
+ # Create node traces with nodes having at least one edge
198
+ nodes_with_edges = set([edge[0] for edge in graph.edges] + [edge[1] for edge in graph.edges])
199
+ x_nodes_with_edges = [pos[i][0] for i in nodes_with_edges]
200
+ y_nodes_with_edges = [pos[i][1] for i in nodes_with_edges]
201
+
202
+ node_trace_with_edges = go.Scatter(x=x_nodes_with_edges, y=y_nodes_with_edges,
203
+ text=[graph.nodes[i]['label'] for i in nodes_with_edges],
204
+ mode='markers+text', hoverinfo='text', marker=dict(color=node_colors, size=10))
205
+
206
+ # Create figure
207
+ fig = go.Figure(data=[edge_trace, node_trace_all, node_trace_with_edges],
208
+ layout=go.Layout(title='Entities and Relationships in Patient Notes',
209
+ titlefont_size=16,
210
+ showlegend=False,
211
+ hovermode='closest',
212
+ margin=dict(b=20, l=5, r=5, t=40),
213
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
214
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
215
+ updatemenus=[dict(type="buttons",
216
+ x=1.15,
217
+ y=1.2,
218
+ buttons=[dict(label="All Entities",
219
+ method="update",
220
+ args=[{"visible": [True, True, False]}]),
221
+ dict(label="Entities with Relationships",
222
+ method="update",
223
+ args=[{"visible": [True, False, True]}])])]))
224
+
225
+ # Display the interactive plot
226
+ return fig
227
 
228
+ # divide the page into 3 columns
229
+ col1, col2, col3 = st.columns([2,5,2])
230
 
231
+ if 'r' not in st.session_state:
232
+ st.session_state.r = 'value'
233
 
234
+ if 'r_annotated' not in st.session_state:
235
+ st.session_state.r_annotated = 'value'
236
 
237
+ if 'colour_to_category' not in st.session_state:
238
+ st.session_state.colour_to_category = 'value'
 
239
 
240
+ with col1:
241
+ col1.subheader("DAX Express Input")
 
 
242
 
243
+ st.text("Enter your DAX Express output below:")
244
 
245
+ dax_input = st.text_area("", height=500)
246
+ analyze_btn = st.button("Analyze")
247
+
248
+ with col2:
249
+ col2.subheader("Text Analytics for Health Output")
250
+ if st.session_state.r_annotated != 'value':
251
+ with st.expander("Entity Mappings"):
252
+ st.markdown(st.session_state.r_annotated, unsafe_allow_html=True)
253
+ with st.expander("Show Relationships"):
254
+ st.plotly_chart(create_interactive_graph_from_json(st.session_state.r, st.session_state.colour_to_category), use_container_width=True)
255
+ with st.expander("Show JSON"):
256
+ st.json(st.session_state.r)
257
+
258
+ if analyze_btn:
259
+
260
+ st.session_state.r = analyze_healthcare_text(dax_input)["tasks"]["items"][0]["results"]
261
+ r_annotated, category_to_color = annotate_text_with_entities(dax_input, st.session_state.r)
262
+ st.session_state.r_annotated = r_annotated
263
+ st.session_state.colour_to_category = category_to_color
264
+ with st.expander("Entity Mappings"):
265
+ st.markdown(r_annotated, unsafe_allow_html=True)
266
+ with st.expander("Show Relationships"):
267
+ st.plotly_chart(create_interactive_graph_from_json(st.session_state.r, category_to_color), use_container_width=True)
268
+ with st.expander("Show JSON"):
269
+ st.json(st.session_state.r)
270
 
271
+
272
+
273
+
274
 
275
+ with col3:
276
+ col3.subheader("Copilot Concept")
 
277
 
278
+ question = st.text_input("Ask a question to Copilot:")
279
 
280
+ toggle = st.toggle("Grounded", False)
281
 
282
 
283
+ ask = st.button("Ask")
284
+ if toggle:
285
+ if ask:
286
+ json_spec = JsonSpec(dict_=st.session_state.r, max_value_length=7000)
287
+ json_toolkit = JsonToolkit(spec=json_spec)
288
 
289
+ json_agent_executor = create_json_agent(
290
+ llm=AzureChatOpenAI(temperature=0, deployment_name="gpt-4"), toolkit=json_toolkit, verbose=True
291
+ )
292
 
293
+ st.write(json_agent_executor.run(question))
 
 
 
294
 
295
+ elif ask:
296
+ response = openai.ChatCompletion.create(
297
+ engine="gpt-4",
298
+ messages = [{"role":"system","content":"You are an AI assistant that helps people find information."}, {"role": "user", "content" : question}],
299
+ temperature=0.7,
300
+ max_tokens=800,
301
+ top_p=0.95,
302
+ frequency_penalty=0,
303
+ presence_penalty=0,
304
+ stop=None)
305
+ st.write(response.choices[0].message.content)