import openai
import streamlit as st
import os
import json
import time
import requests
from html import escape
from random import choice
import plotly.graph_objects as go
import networkx as nx
import os
import openai
from langchain.agents import create_json_agent, AgentExecutor
from langchain.agents.agent_toolkits import JsonToolkit
from langchain.chains import LLMChain
from langchain.chat_models import AzureChatOpenAI
from langchain.requests import TextRequestsWrapper
from langchain.tools.json.tool import JsonSpec
from langchain.indexes import VectorstoreIndexCreator
# ---------------------------------------------------------------------------
# Azure OpenAI configuration.
# Each setting is written both to the environment and to the `openai` module.
# ---------------------------------------------------------------------------
openai.api_type = "azure"

# The endpoint is hardcoded: replace it with your own Azure OpenAI endpoint.
# (An earlier value was "https://japan-sean-aoai.openai.azure.com/".)
os.environ["OPENAI_API_BASE"] = "https://openaihelia.openai.azure.com/"
openai.api_base = os.environ["OPENAI_API_BASE"]

# API version used for the OpenAI calls (an earlier value was "2023-04-01").
os.environ["OPENAI_API_VERSION"] = "2024-04-01-preview"
openai.api_version = os.environ["OPENAI_API_VERSION"]

# The key is read from the environment; a missing OPENAI_API_KEY raises
# KeyError here, when the script starts.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Lay the page out in "wide" mode.
st.set_page_config(layout="wide")
#Azure AI Services
def analyze_healthcare_text(text, poll_interval=1.0, max_wait=300.0, request_timeout=30.0):
    """Submit *text* to the Text Analytics for Health job API and wait for the result.

    A job is created with a POST request, then polled with GET requests until
    the service reports a terminal status.

    Args:
        text: The document to analyse (sent as a single English document).
        poll_interval: Seconds to sleep between two status polls.
        max_wait: Maximum number of seconds to keep polling before giving up.
        request_timeout: Timeout, in seconds, applied to every HTTP request.

    Returns:
        The decoded JSON body of the job once its status is ``succeeded``.

    Raises:
        KeyError: If the ``TA4HAPIKEY`` environment variable is not set.
        requests.HTTPError: If the service answers with an HTTP error status.
        RuntimeError: If the submit response carries no ``operation-location``
            header, or the job ends as ``failed`` / ``cancelled``.
        TimeoutError: If the job has not finished after ``max_wait`` seconds.
    """
    # Hardcoded endpoint: replace only the base URL (scheme + host) with your
    # own resource and keep the "/language/analyze-text/jobs" path.
    base_url = "https://azaiserviceshm.cognitiveservices.azure.com/language/analyze-text/jobs"
    api_version = "2023-04-01"  # previously "2022-10-01-preview"
    headers = {
        "Content-Type": "application/json",
        "Ocp-Apim-Subscription-Key": os.environ["TA4HAPIKEY"],
    }

    # Payload of the initial POST request.
    data = {
        "tasks": [
            {
                "kind": "Healthcare",
                # Offsets/lengths are later used to slice Python strings, which
                # are indexed by Unicode code point, so ask for that unit.
                "parameters": {"stringIndexType": "UnicodeCodePoint"},
            }
        ],
        "analysisInput": {
            "documents": [
                {
                    "id": "documentId",
                    "text": text,
                    "language": "en",
                }
            ]
        },
    }

    # Create the job and fail loudly on an HTTP error instead of crashing later
    # on a missing header.
    response = requests.post(
        f"{base_url}?api-version={api_version}",
        headers=headers,
        json=data,
        timeout=request_timeout,
    )
    response.raise_for_status()

    operation_location = response.headers.get('operation-location')
    if not operation_location:
        raise RuntimeError("The service did not return an 'operation-location' header.")

    # The job id is the last path segment of the operation-location URL.
    job_id = operation_location.split('/')[-1].split('?')[0]

    # Poll until the job reaches a terminal state or the deadline passes.
    deadline = time.monotonic() + max_wait
    while True:
        result_response = requests.get(
            f"{base_url}/{job_id}?api-version={api_version}",
            headers=headers,
            timeout=request_timeout,
        )
        result_response.raise_for_status()
        result = result_response.json()
        status = result.get('status')

        if status == 'succeeded':
            return result
        if status in ('failed', 'cancelled'):
            # Without this check a failed job would be polled forever.
            raise RuntimeError(
                f"Healthcare analysis job {job_id} ended with status '{status}': "
                f"{result.get('errors')}"
            )
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Healthcare analysis job {job_id} did not finish within {max_wait} seconds "
                f"(last status: '{status}')."
            )
        time.sleep(poll_interval)
def annotate_text_with_entities(original_text, entities_data):
    """Render a note as HTML with every recognised entity highlighted.

    The returned markup is meant to be displayed with raw HTML enabled, so all
    text taken from the note or from the analysis result is HTML-escaped.

    Args:
        original_text: The analysed note. Entity offsets are applied to this
            string with plain Python slicing.
        entities_data: Mapping with a ``documents`` list whose first item has
            an ``entities`` list. Each entity is read for ``text``,
            ``category``, ``offset``, ``length`` and, optionally, ``links``.

    Returns:
        A ``(html, category_to_color)`` tuple: the annotated markup prefixed
        with a colour key, and the ``category -> "#rrggbb"`` mapping used.

    Side effects:
        Dumps ``entities_data`` to ``entities_data.json`` in the current
        working directory.
    """
    # Save the raw result to a file (behaviour kept from the original).
    # NOTE(review): this writes analysed note data to disk on every call —
    # confirm it is wanted outside of debugging.
    with open('entities_data.json', 'w') as f:
        json.dump(entities_data, f)

    # Color palette for different categories
    PALETTE = [
        "#ff4b4b",
        "#ffa421",
        "#ffe312",
        "#21c354",
        "#00d4b1",
        "#00c0f2",
        "#1c83e1",
        "#803df5",
        "#808495",
    ]
    # Alpha suffixes appended to the category colour (picked by offset parity).
    OPACITIES = ["33", "66"]

    entities = entities_data['documents'][0]['entities']

    # Sort the categories so the same input always gets the same colours
    # (iterating a bare set of strings is not stable across interpreter runs).
    unique_categories = sorted({entity['category'] for entity in entities})

    # Palette colours first; random colours only once the palette is exhausted.
    category_to_color = {}
    for i, category in enumerate(unique_categories):
        if i < len(PALETTE):
            category_to_color[category] = PALETTE[i]
        else:
            category_to_color[category] = '#' + ''.join(choice('0123456789ABCDEF') for _ in range(6))

    def create_entity_html(entity):
        """Return the highlighted, escaped markup for a single entity."""
        color = category_to_color[entity["category"]] + OPACITIES[entity["offset"] % len(OPACITIES)]
        category = escape(str(entity["category"]))
        # Everything stays on one line so Markdown does not split the HTML.
        parts = [
            f'<span style="background-color: {color}; padding: 0.1em 0.3em; '
            f'border-radius: 0.3em;">{escape(entity["text"])} '
        ]
        links = entity.get("links")
        if links:
            # Dropdown listing the linked codes, headed by the category.
            # Assumes each link is a dict with "dataSource" and "id" keys —
            # TODO confirm against the service response.
            options = "".join(
                f'<option>{escape(str(link.get("dataSource", "")))}: '
                f'{escape(str(link.get("id", "")))}</option>'
                for link in links
            )
            parts.append(
                f'<select style="font-size: 0.7em;"><option>{category}</option>{options}</select>'
            )
        else:
            # No links: just show the category label.
            parts.append(f'<span style="font-size: 0.7em; opacity: 0.7;">{category}</span>')
        # Close the main span element
        parts.append('</span>')
        return "".join(parts)

    # Build the output left to right from slices of the *original* text, so
    # offsets stay valid and the text between entities can be escaped.
    pieces = []
    cursor = 0
    for entity in sorted(entities, key=lambda item: item['offset']):
        start = entity['offset']
        end = start + entity['length']
        if start < cursor:
            # Overlaps an entity already emitted; rendering it as well would
            # duplicate or garble the text.
            continue
        pieces.append(escape(original_text[cursor:start]))
        pieces.append(create_entity_html(entity))
        cursor = end
    pieces.append(escape(original_text[cursor:]))

    # Colour key: one swatch per category.
    swatches = "".join(
        f'<span style="background-color: {color}; padding: 0.1em 0.3em; '
        f'border-radius: 0.3em;">{escape(str(category))}</span> '
        for category, color in category_to_color.items()
    )
    color_key_section = "Color Key: " + swatches + "<br><br>"

    return color_key_section + "".join(pieces), category_to_color
def create_interactive_graph_from_json(json_data, category_to_color):
    """Build a Plotly figure of the entities and their relations.

    Every entity of the first document becomes a node, and every relation
    becomes a directed edge from its first to its second entity. Two buttons
    switch between all nodes and only the nodes that take part in a relation.

    Args:
        json_data: Mapping with a ``documents`` list whose first item has an
            ``entities`` list and, optionally, a ``relations`` list.
        category_to_color: Mapping from entity category to a colour string.

    Returns:
        A ``plotly.graph_objects.Figure`` with three traces: the edges, all
        nodes, and the nodes that have at least one edge.
    """
    document = json_data['documents'][0]
    entities = document['entities']
    relations = document.get('relations', [])

    # Directed graph whose node ids are the entity positions in the list.
    graph = nx.DiGraph()
    for i, entity in enumerate(entities):
        graph.add_node(i, label=entity['text'], category=entity['category'])

    for relation in relations:
        members = relation['entities']
        if len(members) < 2:
            continue
        # The entity index is the last path segment of each "ref".
        source_index = int(members[0]['ref'].split('/')[-1])
        target_index = int(members[1]['ref'].split('/')[-1])
        # add_edge would silently create nodes without label/category for an
        # unknown index, which breaks the lookups below; skip such relations.
        if source_index in graph and target_index in graph:
            graph.add_edge(source_index, target_index, label=relation['relationType'])

    # A fixed seed keeps the layout stable when the script re-runs.
    pos = nx.spring_layout(graph, seed=42)

    # Edge coordinates; None separates the individual line segments.
    x_edges = []
    y_edges = []
    for source, target in graph.edges:
        x_edges += [pos[source][0], pos[target][0], None]
        y_edges += [pos[source][1], pos[target][1], None]
    edge_trace = go.Scatter(x=x_edges, y=y_edges, line=dict(width=0.5, color='#888'),
                            hoverinfo='none', mode='lines')

    def make_node_trace(node_ids, visible=True):
        """Scatter trace for *node_ids*, coloured by each node's own category."""
        return go.Scatter(
            x=[pos[i][0] for i in node_ids],
            y=[pos[i][1] for i in node_ids],
            text=[graph.nodes[i]['label'] for i in node_ids],
            mode='markers+text',
            hoverinfo='text',
            visible=visible,
            # Colours are derived from the same id list as x/y/text, so the
            # subset trace no longer borrows the colours of unrelated nodes.
            marker=dict(color=[category_to_color[graph.nodes[i]['category']] for i in node_ids],
                        size=10),
        )

    node_trace_all = make_node_trace(list(graph.nodes))

    # Nodes with at least one edge, in a deterministic order. Hidden at first
    # so the initial view matches the "All Entities" button.
    nodes_with_edges = sorted({node for edge in graph.edges for node in edge})
    node_trace_with_edges = make_node_trace(nodes_with_edges, visible=False)

    fig = go.Figure(
        data=[edge_trace, node_trace_all, node_trace_with_edges],
        layout=go.Layout(
            # title.font.size replaces the deprecated "titlefont_size".
            title=dict(text='Entities and Relationships in Patient Notes', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            updatemenus=[dict(
                type="buttons",
                x=1.15,
                y=1.2,
                buttons=[
                    dict(label="All Entities",
                         method="update",
                         args=[{"visible": [True, True, False]}]),
                    dict(label="Entities with Relationships",
                         method="update",
                         args=[{"visible": [True, False, True]}]),
                ],
            )],
        ),
    )
    return fig
def format_sdoh_entities_as_list(json_data):
    """Return the social-determinants-of-health entities as a Markdown list.

    Categories are matched ignoring case and underscores, so both
    ``LivingStatus`` and ``LIVING_STATUS`` are recognised. The previous exact
    comparison against ``'LIVING_STATUS'`` could never match a category written
    without the underscore.

    Args:
        json_data: Mapping with a ``documents`` list; each document's
            ``entities`` are read for ``category`` and ``text``.

    Returns:
        One Markdown bullet per matching entity, joined with newlines, or an
        empty string when nothing matches.
    """
    # Key: category with case and underscores removed. Value: label displayed.
    sdoh_labels = {
        'EMPLOYMENT': 'EMPLOYMENT',
        'LIVINGSTATUS': 'LIVING_STATUS',
        'SUBSTANCEUSE': 'SUBSTANCEUSE',
        'SUBSTANCEUSEAMOUNT': 'SUBSTANCEUSEAMOUNT',
        'ETHNICITY': 'ETHNICITY',
    }
    formatted_result = []
    for document in json_data['documents']:
        for entity in document.get('entities', []):
            label = sdoh_labels.get(entity['category'].upper().replace('_', ''))
            if label is not None:
                formatted_result.append(f"- **{label}** : '{entity['text']}' \n")
    # The debug print of the result was dropped: it echoed note content to stdout.
    return '\n'.join(formatted_result)
# Three columns with relative widths 2 : 5 : 2.
col1, col2, col3 = st.columns([2, 5, 2])

# Seed every session-state slot with the placeholder 'value' on first run.
# Later code compares against this placeholder to tell whether results exist.
for _state_key in ('r', 'r_annotated', 'colour_to_category'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = 'value'

# Left column: the note to analyse and the button that triggers the analysis.
with col1:
    st.subheader("Patient Note Input")
    st.text("Enter your text input below:")
    dax_input = st.text_area("", height=500)
    analyze_btn = st.button("Analyze")
# Middle column: run the analysis on demand, then render whatever results are
# stored in the session state.
with col2:
    col2.subheader("Text Analytics for Health Output")

    # Step 1: refresh the stored results when "Analyze" was clicked. Doing this
    # before rendering means the results are drawn exactly once; previously a
    # click with earlier results present rendered the stale and fresh sets.
    if analyze_btn:
        if dax_input.strip():
            with st.spinner("Analyzing"):
                job = analyze_healthcare_text(dax_input)
            st.session_state.r = job["tasks"]["items"][0]["results"]
            r_annotated, category_to_color = annotate_text_with_entities(dax_input, st.session_state.r)
            st.session_state.r_annotated = r_annotated
            st.session_state.colour_to_category = category_to_color
        else:
            # Nothing to analyse: do not call the service with an empty document.
            st.warning("Please enter a patient note before clicking Analyze.")

    # Step 2: render the stored results ('value' is the "nothing yet" placeholder).
    if st.session_state.r_annotated != 'value':
        with st.expander("Entity Mappings"):
            st.markdown(st.session_state.r_annotated, unsafe_allow_html=True)
        with st.expander("Show Relationships"):
            st.plotly_chart(
                create_interactive_graph_from_json(st.session_state.r, st.session_state.colour_to_category),
                use_container_width=True,
            )
        with st.expander("Show JSON"):
            st.json(st.session_state.r)
        with st.expander("Show SDOH"):
            st.write("Social Determinants of Health (SDOH) Entities")
            st.write(format_sdoh_entities_as_list(st.session_state.r))
def _chat_completion(deployment, system_prompt, user_content, temperature, max_tokens):
    """Send one system + user exchange to the chat deployment and return the reply text.

    Uses the module-level ``openai.ChatCompletion`` interface, as the rest of
    this script does.
    """
    response = openai.ChatCompletion.create(
        engine=deployment,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None)
    return response.choices[0].message.content


# Right column: free-form questions to the model, optionally grounded on the
# note and/or the analysis results, plus an SDOH report generator.
with col3:
    col3.subheader("Copilot Concept")
    question = st.text_input("Ask a question to Copilot:")
    grounded = st.selectbox('Would you like to ground the model?', ('Not Grounded', 'Text Analytics for Health', 'Just Text Input', 'Both'))
    # model is hardcoded - replace with whatever your deployment name is below
    model = st.selectbox("Model", ["gpt-4o"])
    ask = st.button("Ask")
    report_btn = st.button("Generate SDOH Report")

    if report_btn:
        if dax_input.strip():
            with st.spinner("Generating Report"):
                report = _chat_completion(
                    model,
                    "Your job is to write a brief list of potential SDOH information in the data you are given. Then give recomendations on things a doctor might consider in relation to those factors when providing care.",
                    f"data: ``` {dax_input} ```",
                    temperature=0,
                    max_tokens=800,
                )
            st.write(report)
        else:
            st.warning("Please enter a patient note before generating a report.")

    if ask:
        # 'value' is the placeholder stored before any analysis has been run.
        has_analysis = st.session_state.r != 'value'
        needs_analysis = grounded in ('Text Analytics for Health', 'Both')

        if not question.strip():
            st.warning("Please enter a question for Copilot.")
        elif grounded == 'Not Grounded':
            st.write(_chat_completion(
                model,
                "You are an AI assistant that helps people find information. Be super brief.",
                question,
                temperature=0.7,
                max_tokens=200,
            ))
        elif needs_analysis and not has_analysis:
            # Previously the placeholder string itself was sent as "data".
            st.warning("Run Analyze first so there are Text Analytics for Health results to ground on.")
        else:
            if grounded == 'Text Analytics for Health':
                total_text = str(st.session_state.r)
            elif grounded == 'Just Text Input':
                total_text = dax_input
            else:  # 'Both'
                total_text = str(st.session_state.r) + dax_input
            st.write(_chat_completion(
                model,
                f"You are an AI assistant that helps doctors find information in complex structured data below. Answer the question directly and very briefly using the data provided. \n Question: {question}, if you do not have information below to answer it, state that you can't answer the question with the information you have been given. ",
                f"data: ``` {total_text} ```",
                temperature=0,
                max_tokens=300,
            ))