Spaces:
Sleeping
Sleeping
File size: 16,026 Bytes
852ef53 fd3f788 d0690e4 852ef53 f1f7255 5f0bc7f 8f06900 332c65f cbc93d6 663ae5c fd3f788 852ef53 8f06900 852ef53 5f0bc7f 8f06900 852ef53 050a938 852ef53 332c65f 852ef53 fd3f788 852ef53 fd3f788 332c65f fd3f788 852ef53 cbc93d6 852ef53 fd3f788 cbc93d6 fd3f788 852ef53 fd3f788 852ef53 fd3f788 852ef53 fd3f788 852ef53 fd3f788 663ae5c fd3f788 852ef53 fd3f788 9551cbd fd3f788 9551cbd fd3f788 852ef53 fd3f788 852ef53 fd3f788 852ef53 fd3f788 852ef53 cbc93d6 8e8d035 cbc93d6 fd3f788 852ef53 fd3f788 852ef53 fd3f788 852ef53 fd3f788 852ef53 fd3f788 9551cbd 852ef53 9551cbd 852ef53 fd3f788 cbc93d6 fd3f788 cbc93d6 852ef53 fd3f788 28dc1ef 5f0bc7f fe22626 fd3f788 cbc93d6 28dc1ef fd3f788 28dc1ef 9551cbd cbc93d6 9551cbd 28dc1ef 852ef53 cbc93d6 fe22626 cbc93d6 fd3f788 9551cbd fd3f788 cbc93d6 fd3f788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 |
import openai
import streamlit as st
import os
import json
import time
import requests
from html import escape
from random import choice
import plotly.graph_objects as go
import networkx as nx
import os
import openai
from langchain.agents import create_json_agent, AgentExecutor
from langchain.agents.agent_toolkits import JsonToolkit
from langchain.chains import LLMChain
from langchain.chat_models import AzureChatOpenAI
from langchain.requests import TextRequestsWrapper
from langchain.tools.json.tool import JsonSpec
from langchain.indexes import VectorstoreIndexCreator
openai.api_type = "azure"
# The Azure OpenAI endpoint and API version below are hardcoded defaults - replace them
# with your own resource, or set OPENAI_API_BASE / OPENAI_API_VERSION in the environment
# (an existing environment value takes precedence over the default).
os.environ["OPENAI_API_BASE"] = openai.api_base = os.environ.get("OPENAI_API_BASE", "https://openaihelia.openai.azure.com/")
os.environ["OPENAI_API_VERSION"] = openai.api_version = os.environ.get("OPENAI_API_VERSION", "2024-04-01-preview")
# Raises KeyError at startup if the key is not configured.
openai.api_key = os.environ["OPENAI_API_KEY"]
st.set_page_config(layout="wide")
#Azure AI Services
def analyze_healthcare_text(text, poll_interval=1.0, max_wait=300.0):
    """Run Text Analytics for Health on *text* and return the completed job as JSON.

    The service processes healthcare analysis as an asynchronous job: the POST
    below creates the job and returns its URL in the ``operation-location``
    response header, and the job is then polled with GET until it finishes.

    Args:
        text: Clinical note to analyze; sent as a single English document.
        poll_interval: Seconds to sleep between status polls.
        max_wait: Maximum number of seconds to keep polling before giving up.

    Returns:
        The decoded JSON body of the last GET response, whose ``status`` is
        ``"succeeded"``.

    Raises:
        KeyError: If the ``TA4HAPIKEY`` environment variable is not set.
        requests.HTTPError: If the service answers either request with an
            HTTP error status.
        RuntimeError: If no ``operation-location`` header is returned, the job
            ends as failed / cancelled / partially completed, or ``max_wait``
            seconds pass without the job succeeding.
    """
    # Hardcoded endpoint: replace the base URL with your own Language resource but keep
    # the path, e.g. https://ta4h-endpoint.cognitiveservices.azure.com/language/analyze-text/jobs
    base_url = "https://azaiserviceshm.cognitiveservices.azure.com/language/analyze-text/jobs"
    api_version = "2023-04-01"
    headers = {
        "Content-Type": "application/json",
        "Ocp-Apim-Subscription-Key": os.environ["TA4HAPIKEY"],
    }
    data = {
        "tasks": [
            {
                "kind": "Healthcare",
                # Request offsets counted in Unicode code points so they line up with
                # Python string indices (annotate_text_with_entities slices with them).
                "parameters": {"stringIndexType": "UnicodeCodePoint"},
            }
        ],
        "analysisInput": {
            "documents": [
                {"id": "documentId", "text": text, "language": "en"}
            ]
        },
    }

    # Create the job; the service replies with the job URL in a response header.
    response = requests.post(f"{base_url}?api-version={api_version}", headers=headers, json=data, timeout=30)
    response.raise_for_status()
    operation_location = response.headers.get('operation-location')
    if not operation_location:
        raise RuntimeError("Text Analytics for Health response did not include an operation-location header")
    # The job id is the last path segment of the job URL, without its query string.
    job_id = operation_location.split('/')[-1].split('?')[0]

    # Poll until the job reaches a terminal status or the time budget runs out.
    deadline = time.monotonic() + max_wait
    while True:
        result_response = requests.get(f"{base_url}/{job_id}?api-version={api_version}", headers=headers, timeout=30)
        result_response.raise_for_status()
        result = result_response.json()
        status = result.get('status')
        if status == 'succeeded':
            return result
        if status in ('failed', 'cancelled', 'partiallyCompleted'):
            raise RuntimeError(
                f"Text Analytics for Health job {job_id} ended with status {status!r}: {result.get('errors')}"
            )
        if time.monotonic() >= deadline:
            raise RuntimeError(
                f"Text Analytics for Health job {job_id} did not finish within {max_wait} seconds "
                f"(last status: {status!r})"
            )
        time.sleep(poll_interval)
def annotate_text_with_entities(original_text, entities_data):
    """Render *original_text* as HTML with every detected entity highlighted.

    All text taken from the note or the service response is HTML-escaped,
    because the result is rendered with ``unsafe_allow_html=True``.

    Args:
        original_text: The text that was sent for analysis; entity offsets index into it.
        entities_data: Task results; entities are read from
            ``entities_data['documents'][0]['entities']`` and each needs ``text``,
            ``category``, ``offset`` and ``length`` keys (``links`` is optional).

    Returns:
        ``(html, category_to_color)``: an HTML string made of a colour key followed
        by the escaped text with entities wrapped in coloured spans, and a dict
        mapping each entity category to its ``#RRGGBB`` colour.

    Side effects:
        Writes *entities_data* to ``entities_data.json`` in the working directory.
    """
    # Keep a copy of the raw service output on disk for inspection.
    with open('entities_data.json', 'w', encoding='utf-8') as f:
        json.dump(entities_data, f)

    # Color palette for different categories
    PALETTE = [
        "#ff4b4b",
        "#ffa421",
        "#ffe312",
        "#21c354",
        "#00d4b1",
        "#00c0f2",
        "#1c83e1",
        "#803df5",
        "#808495",
    ]
    # Hex alpha suffixes appended to the category colour, picked by entity offset parity.
    OPACITIES = ["33", "66"]

    entities = entities_data['documents'][0]['entities']

    # Sorted so each category gets the same colour on every run (iteration order of a
    # set of strings varies between interpreter runs).
    unique_categories = sorted({entity['category'] for entity in entities})
    category_to_color = {}
    for i, category in enumerate(unique_categories):
        if i < len(PALETTE):
            category_to_color[category] = PALETTE[i]
        else:
            # Palette exhausted: fall back to a random colour.
            category_to_color[category] = '#' + ''.join(choice('0123456789ABCDEF') for _ in range(6))

    divider = '<span style="border-left: 1px solid; opacity: 0.1; margin-left: 0.5rem; align-self: stretch;"></span>'

    def create_entity_html(entity, entity_id):
        """Return the highlighted HTML for one entity: its text, category label and optional code dropdown."""
        color = category_to_color[entity["category"]] + OPACITIES[entity["offset"] % len(OPACITIES)]
        category = escape(entity["category"])
        entity_html = (
            f'<span id="entity-{entity_id}"><span style="display: inline-flex; flex-direction: row; '
            f'align-items: center; background: {color}; border-radius: 0.5rem; padding: 0.25rem 0.5rem; '
            f'overflow: hidden; line-height: 1;">{escape(entity["text"])}'
        )
        links = entity.get("links")
        if links:
            # Linked codes are offered in a dropdown, with the category label underneath.
            # Built on single lines so Markdown rendering cannot reinterpret indentation.
            options = "".join(
                f'<option value="{escape(str(link["id"]))}">'
                f'{escape(str(link["dataSource"]))} Code {escape(str(link["id"]))}</option>'
                for link in links
            )
            entity_html += (
                divider
                + '<span style="margin-left: 0.5rem; display: flex; flex-direction: column; align-items: flex-start;">'
                + f'<select style="font-size: 0.75rem; opacity: 0.5;">{options}</select>'
                + f'<label style="font-size: 0.6rem; margin-top: 0.25rem;">{category}</label>'
                + '</span>'
            )
        else:
            # No links: just show the category label.
            entity_html += divider + f'<span style="margin-left: 0.5rem; font-size: 0.75rem; opacity: 0.5;">{category}</span>'
        # Close the inner styled span and the outer id span.
        entity_html += '</span></span>'
        return entity_html

    entity_htmls = [create_entity_html(entity, i) for i, entity in enumerate(entities)]

    # Rebuild the text left to right: escape the plain text between entities and
    # splice in each entity's HTML.
    pieces = []
    cursor = 0
    for entity, entity_html in sorted(zip(entities, entity_htmls), key=lambda pair: pair[0]['offset']):
        start = entity['offset']
        end = start + entity['length']
        if start < cursor:
            # Overlaps an entity that was already emitted; wrapping it too would
            # produce broken markup, so it is left out of the inline view.
            continue
        pieces.append(escape(original_text[cursor:start]))
        pieces.append(entity_html)
        cursor = end
    pieces.append(escape(original_text[cursor:]))

    # Prepend a colour key listing every category.
    color_key_section = "<strong>Color Key:</strong><br>" + "".join(
        f'<span style="display: inline-block; background: {color}; width: 1em; height: 1em; '
        f'margin-right: 0.5em; vertical-align: middle;"></span>{escape(category)}<br>'
        for category, color in category_to_color.items()
    )
    return color_key_section + "".join(pieces), category_to_color
def create_interactive_graph_from_json(json_data, category_to_color):
    """Build a Plotly figure showing the entities and relations of the first document.

    Args:
        json_data: Task results; reads ``documents[0]['entities']`` and, if present,
            ``documents[0]['relations']``.
        category_to_color: Mapping from entity category to marker colour; must
            contain every category that occurs in the entities.

    Returns:
        A ``plotly.graph_objects.Figure`` with three traces (edges, all entity
        nodes, nodes that take part in a relation) and two buttons that toggle
        between the "all entities" and "entities with relationships" views.
    """
    document = json_data['documents'][0]
    entities = document['entities']
    relations = document.get('relations', [])

    # Directed graph: one node per entity (keyed by list index), one edge per relation.
    graph = nx.DiGraph()
    for i, entity in enumerate(entities):
        graph.add_node(i, label=entity['text'], category=entity['category'])
    for relation in relations:
        # The first two relation endpoints become source and target; each 'ref'
        # string ends in the referenced entity's index.
        source_index = int(relation['entities'][0]['ref'].split('/')[-1])
        target_index = int(relation['entities'][1]['ref'].split('/')[-1])
        graph.add_edge(source_index, target_index, label=relation['relationType'])

    pos = nx.spring_layout(graph)

    # All edges go into one line trace; None breaks the line between segments.
    x_edges, y_edges = [], []
    for source, target in graph.edges:
        x_edges += [pos[source][0], pos[target][0], None]
        y_edges += [pos[source][1], pos[target][1], None]
    edge_trace = go.Scatter(x=x_edges, y=y_edges, line=dict(width=0.5, color='#888'), hoverinfo='none', mode='lines')

    def node_trace(nodes, visible):
        """Scatter trace for *nodes*; colours are computed per node so they stay aligned with positions."""
        return go.Scatter(
            x=[pos[n][0] for n in nodes],
            y=[pos[n][1] for n in nodes],
            text=[graph.nodes[n]['label'] for n in nodes],
            mode='markers+text',
            hoverinfo='text',
            marker=dict(color=[category_to_color[graph.nodes[n]['category']] for n in nodes], size=10),
            visible=visible,
        )

    node_trace_all = node_trace(list(graph.nodes), True)
    # Sorted so the point order of the trace is deterministic.
    nodes_with_edges = sorted({node for edge in graph.edges for node in edge})
    # Hidden at first so the initial view matches the "All Entities" button state.
    node_trace_with_edges = node_trace(nodes_with_edges, False)

    fig = go.Figure(
        data=[edge_trace, node_trace_all, node_trace_with_edges],
        layout=go.Layout(
            title=dict(text='Entities and Relationships in Patient Notes', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            updatemenus=[dict(
                type="buttons",
                x=1.15,
                y=1.2,
                buttons=[
                    # Visibility flags are per trace: [edges, all nodes, related nodes].
                    dict(label="All Entities", method="update",
                         args=[{"visible": [True, True, False]}]),
                    dict(label="Entities with Relationships", method="update",
                         args=[{"visible": [True, False, True]}]),
                ],
            )],
        ),
    )
    return fig
def format_sdoh_entities_as_list(json_data):
    """Format social-determinants-of-health (SDOH) entities as a Markdown bullet list.

    Args:
        json_data: Text Analytics for Health results; every entity of every
            document in ``json_data['documents']`` is inspected.

    Returns:
        One ``- **CATEGORY** : 'text'`` bullet per matching entity (category shown
        upper-cased), in document/entity order, joined by newlines; an empty
        string when nothing matches.
    """
    # Categories are compared upper-cased with underscores removed. A CamelCase
    # category such as "LivingStatus" upper-cases to "LIVINGSTATUS", which a
    # "LIVING_STATUS" entry would never match.
    relevant_categories = {'EMPLOYMENT', 'LIVINGSTATUS', 'SUBSTANCEUSE', 'SUBSTANCEUSEAMOUNT', 'ETHNICITY'}
    formatted_result = []
    for document in json_data['documents']:
        for entity in document['entities']:
            category = entity['category'].upper()
            if category.replace('_', '') in relevant_categories:
                formatted_result.append(f"- **{category}** : '{entity['text']}' \n")
    return '\n'.join(formatted_result)
# Split the page into three columns (input, analysis output, copilot) at a 2:5:2 width ratio.
col1, col2, col3 = st.columns([2, 5, 2])
# Seed session state with the 'value' placeholder; later code compares against it
# to tell whether an analysis has been run yet.
for _state_key in ('r', 'r_annotated', 'colour_to_category'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = 'value'
# Left column: the patient note to analyze and the button that triggers analysis.
with col1:
    col1.subheader("Patient Note Input")
    col1.text("Enter your text input below:")
    dax_input = col1.text_area("", height=500)
    analyze_btn = col1.button("Analyze")
# Middle column: runs the analysis when requested and renders the latest results.
with col2:
    col2.subheader("Text Analytics for Health Output")
    # Analyze *before* rendering, so pressing "Analyze" shows only the new results
    # rather than the previous results followed by a second set of expanders.
    if analyze_btn:
        if not dax_input.strip():
            st.warning("Please enter a patient note to analyze.")
        else:
            with st.spinner("Analyzing text"):
                st.session_state.r = analyze_healthcare_text(dax_input)["tasks"]["items"][0]["results"]
            r_annotated, category_to_color = annotate_text_with_entities(dax_input, st.session_state.r)
            st.session_state.r_annotated = r_annotated
            st.session_state.colour_to_category = category_to_color
    # 'value' is the placeholder kept in session state until the first analysis.
    if st.session_state.r_annotated != 'value':
        with st.expander("Entity Mappings"):
            st.markdown(st.session_state.r_annotated, unsafe_allow_html=True)
        with st.expander("Show Relationships"):
            st.plotly_chart(
                create_interactive_graph_from_json(st.session_state.r, st.session_state.colour_to_category),
                use_container_width=True,
            )
        with st.expander("Show JSON"):
            st.json(st.session_state.r)
        with st.expander("Show SDOH"):
            st.write("Social Determinants of Health (SDOH) Entities")
            st.write(format_sdoh_entities_as_list(st.session_state.r))
def _ask_chat_model(model, system_prompt, user_content, temperature, max_tokens):
    """Send one system + user exchange to the Azure OpenAI deployment *model* and return the reply text.

    Uses the legacy ``openai.ChatCompletion`` interface configured at the top of the file.
    """
    response = openai.ChatCompletion.create(
        engine=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )
    return response.choices[0].message.content


# Right column: free-form Q&A and an SDOH report, optionally grounded on the note/analysis.
with col3:
    col3.subheader("Copilot Concept")
    question = st.text_input("Ask a question to Copilot:")
    grounded = st.selectbox('Would you like to ground the model?', ('Not Grounded', 'Text Analytics for Health', 'Just Text Input', 'Both'))
    # model is hardcoded - replace with whatever your deployment name is below
    model = st.selectbox("Model", ["gpt-4o"])
    ask = st.button("Ask")
    report_btn = st.button("Generate SDOH Report")

    if report_btn:
        if not dax_input.strip():
            st.warning("Enter a patient note before generating an SDOH report.")
        else:
            with st.spinner("Generating Report"):
                report = _ask_chat_model(
                    model,
                    "Your job is to write a brief list of potential SDOH information in the data you are given. Then give recomendations on things a doctor might consider in relation to those factors when providing care.",
                    f"data: ``` {dax_input} ```",
                    temperature=0,
                    max_tokens=800,
                )
            st.write(report)

    if ask:
        if not question.strip():
            st.warning("Please enter a question for Copilot.")
        elif grounded == 'Not Grounded':
            st.write(_ask_chat_model(
                model,
                "You are an AI assistant that helps people find information. Be super brief.",
                question,
                temperature=0.7,
                max_tokens=200,
            ))
        elif grounded in ('Text Analytics for Health', 'Both') and st.session_state.r == 'value':
            # 'value' is the placeholder left in session state until "Analyze" has run.
            st.warning('Run "Analyze" on a patient note first to ground answers on Text Analytics for Health results.')
        else:
            if grounded == 'Text Analytics for Health':
                total_text = str(st.session_state.r)
            elif grounded == 'Just Text Input':
                total_text = dax_input
            else:  # 'Both': keep the analysis JSON and the raw note visibly separate.
                total_text = str(st.session_state.r) + "\n\n" + dax_input
            st.write(_ask_chat_model(
                model,
                f"You are an AI assistant that helps doctors find information in complex structured data below. Answer the question directly and very briefly using the data provided. \n Question: {question}, if you do not have information below to answer it, state that you can't answer the question with the information you have been given. ",
                f"data: ``` {total_text} ```",
                temperature=0,
                max_tokens=300,
            ))
|