"""Streamlit demo: Azure Text Analytics for Health plus an Azure OpenAI "copilot".

The left column takes a free-text patient note. The middle column shows the
Text Analytics for Health (TA4H) output: highlighted entities, a relationship
graph, the raw JSON and SDOH entities. The right column sends questions or
an SDOH report request to an Azure OpenAI chat deployment, optionally grounded
on the note and/or the TA4H result.
"""
import json
import os
import time
from html import escape
from random import choice

import networkx as nx
import openai
import plotly.graph_objects as go
import requests
import streamlit as st

# NOTE(review): none of the LangChain names below are referenced in this file.
# They are kept so the module's import-time requirements stay unchanged.
from langchain.agents import create_json_agent, AgentExecutor
from langchain.agents.agent_toolkits import JsonToolkit
from langchain.chains import LLMChain
from langchain.chat_models import AzureChatOpenAI
from langchain.requests import TextRequestsWrapper
from langchain.tools.json.tool import JsonSpec
from langchain.indexes import VectorstoreIndexCreator

openai.api_type = "azure"
# Azure OpenAI endpoint. Set OPENAI_API_BASE to point at your own resource;
# the fallback is the endpoint this demo was originally built against.
os.environ["OPENAI_API_BASE"] = openai.api_base = os.environ.get(
    "OPENAI_API_BASE", "https://openaihelia.openai.azure.com/"
)
os.environ["OPENAI_API_VERSION"] = openai.api_version = "2024-04-01-preview"
openai.api_key = os.environ["OPENAI_API_KEY"]

# Base URL of the Azure AI Language resource used for Text Analytics for
# Health. Set TA4H_ENDPOINT to your own resource's base URL (without the
# /language/... path); the fallback is the endpoint this demo was built against.
TA4H_ENDPOINT = os.environ.get(
    "TA4H_ENDPOINT", "https://azaiserviceshm.cognitiveservices.azure.com"
)
TA4H_API_VERSION = "2023-04-01"
# Job states after which polling stops without a usable result.
_TA4H_FAILED_STATES = {"failed", "cancelled"}

# Placeholder stored in st.session_state until the first successful analysis.
NOT_ANALYZED = "value"

st.set_page_config(layout="wide")


# Azure AI Services
def analyze_healthcare_text(text, poll_interval=1.0, timeout=300.0):
    """Run Text Analytics for Health on *text* and return the finished job JSON.

    Submits an asynchronous analyze-text job with a single "Healthcare" task,
    then polls the job until it succeeds.

    Args:
        text: The note to analyse, sent as one English-language document.
        poll_interval: Seconds to sleep between status polls.
        timeout: Maximum number of seconds to keep polling. ``None`` polls
            until the job reaches a terminal state.

    Returns:
        The decoded JSON body of the final job-status response. The UI below
        reads ``["tasks"]["items"][0]["results"]`` from it.

    Raises:
        requests.HTTPError: If submitting or polling the job returns an HTTP error.
        RuntimeError: If no operation-location header is returned, or the job
            ends in a failed state.
        TimeoutError: If the job has not succeeded within *timeout* seconds.
    """
    base_url = f"{TA4H_ENDPOINT.rstrip('/')}/language/analyze-text/jobs"
    headers = {
        "Content-Type": "application/json",
        "Ocp-Apim-Subscription-Key": os.environ["TA4HAPIKEY"],
    }
    data = {
        "tasks": [{"kind": "Healthcare"}],
        "analysisInput": {
            "documents": [{"id": "documentId", "text": text, "language": "en"}]
        },
    }
    # NOTE(review): entity offsets are later used to slice a Python str
    # (code points); confirm the service's default string index type matches
    # for notes containing emoji or other non-BMP characters.

    response = requests.post(
        f"{base_url}?api-version={TA4H_API_VERSION}",
        headers=headers,
        json=data,
        timeout=30,
    )
    response.raise_for_status()

    # The job id is the last path segment of the operation-location header.
    operation_location = response.headers.get("operation-location")
    if not operation_location:
        raise RuntimeError(
            "Text Analytics for Health did not return an operation-location header"
        )
    job_id = operation_location.split("/")[-1].split("?")[0]

    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        result_response = requests.get(
            f"{base_url}/{job_id}?api-version={TA4H_API_VERSION}",
            headers=headers,
            timeout=30,
        )
        result_response.raise_for_status()
        result = result_response.json()
        status = result.get("status")
        if status == "succeeded":
            return result
        if status in _TA4H_FAILED_STATES:
            # Without this check a failed job would be polled forever.
            raise RuntimeError(
                f"Text Analytics for Health job {job_id} ended with status "
                f"{status!r}: {result.get('errors')}"
            )
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Text Analytics for Health job {job_id} did not finish within "
                f"{timeout} seconds (last status {status!r})"
            )
        time.sleep(poll_interval)


def annotate_text_with_entities(original_text, entities_data):
    """Render *original_text* as HTML with every recognised entity highlighted.

    Args:
        original_text: The exact text that was sent to Text Analytics for Health.
        entities_data: The job ``results`` object. ``documents[0]["entities"]``
            must hold dicts with ``text``, ``category``, ``offset``, ``length``
            and optionally ``links``.

    Returns:
        A ``(html, category_to_color)`` tuple:

        - ``html`` is a color key followed by the annotated text, meant for
          ``st.markdown(..., unsafe_allow_html=True)``.
        - ``category_to_color`` maps each entity category to its ``#RRGGBB``
          color.
    """
    # NOTE(review): this writes the (patient-derived) entity data to the
    # working directory on every call. It looks like a debugging aid; confirm
    # it is wanted before using real patient notes.
    with open("entities_data.json", "w", encoding="utf-8") as f:
        json.dump(entities_data, f)

    # Color palette for different categories.
    palette = [
        "#ff4b4b", "#ffa421", "#ffe312", "#21c354", "#00d4b1",
        "#00c0f2", "#1c83e1", "#803df5", "#808495",
    ]
    # Two-hex-digit alpha suffixes appended to a category color.
    opacities = ["33", "66"]
    highlight_style = "padding: 0.1em 0.3em; border-radius: 0.3em;"

    entities = entities_data["documents"][0]["entities"]

    # sorted() gives each category the same color on every run; the iteration
    # order of a set of strings can change between interpreter runs.
    unique_categories = sorted({entity["category"] for entity in entities})
    category_to_color = {}
    for i, category in enumerate(unique_categories):
        if i < len(palette):
            category_to_color[category] = palette[i]
        else:
            # Palette exhausted: fall back to a random color.
            category_to_color[category] = "#" + "".join(
                choice("0123456789ABCDEF") for _ in range(6)
            )

    def create_entity_html(entity, entity_id):
        """Return the highlighted HTML that replaces one entity mention."""
        # The alpha suffix alternates with the parity of the entity's offset.
        color = (
            category_to_color[entity["category"]]
            + opacities[entity["offset"] % len(opacities)]
        )
        category = escape(entity["category"])
        entity_html = (
            f'<span id="entity-{entity_id}" '
            f'style="background-color: {color}; {highlight_style}">'
            f'{escape(entity["text"])}'
        )
        links = entity.get("links")
        if links:
            # Dropdown headed by the category, then one option per entity link.
            options = "".join(
                f'<option>{escape(str(link.get("dataSource", "")))}: '
                f'{escape(str(link.get("id", "")))}</option>'
                for link in links
            )
            entity_html += (
                f' <select title="{category}">'
                f"<option>{category}</option>{options}</select>"
            )
        else:
            # No links: just show the category label.
            entity_html += (
                f' <span style="font-size: 0.7em; font-weight: bold;">{category}</span>'
            )
        return entity_html + "</span>"

    def plain(segment):
        """Escape non-entity text and keep its line breaks visible in markdown."""
        return escape(segment).replace("\n", "<br>")

    # Walk entities left to right, copying the text between them. Text outside
    # entities is escaped so characters such as "<" in a note are not
    # swallowed as HTML.
    pieces = []
    cursor = 0
    indexed = sorted(enumerate(entities), key=lambda pair: pair[1]["offset"])
    for entity_id, entity in indexed:
        start = entity["offset"]
        end = start + entity["length"]
        if start < cursor:
            # Overlaps an entity already emitted; splicing both would corrupt
            # the HTML, so keep only the earlier one.
            continue
        pieces.append(plain(original_text[cursor:start]))
        pieces.append(create_entity_html(entity, entity_id))
        cursor = end
    pieces.append(plain(original_text[cursor:]))

    key_entries = "".join(
        f'<span style="background-color: {color}; {highlight_style}">'
        f"{escape(category)}</span><br>"
        for category, color in category_to_color.items()
    )
    color_key_section = f"<b>Color Key:</b><br>{key_entries}<br>"

    return color_key_section + "".join(pieces), category_to_color


def create_interactive_graph_from_json(json_data, category_to_color):
    """Build a Plotly figure of entities (nodes) and relations (directed edges).

    Args:
        json_data: The job ``results`` object. Entities are read from
            ``documents[0]["entities"]`` and relations from
            ``documents[0]["relations"]`` (missing means none).
        category_to_color: Mapping from entity category to marker color, as
            returned by ``annotate_text_with_entities``. It must cover every
            category present.

    Returns:
        A ``go.Figure`` with three traces, in order:

        - the edges;
        - all entity nodes;
        - only the nodes that take part in a relation.

        Two buttons switch between the second and third traces.
    """
    document = json_data["documents"][0]
    entities = document["entities"]
    relations = document.get("relations", [])

    graph = nx.DiGraph()
    for i, entity in enumerate(entities):
        graph.add_node(i, label=entity["text"], category=entity["category"])

    for relation in relations:
        refs = relation.get("entities", [])
        if len(refs) < 2:
            continue
        # Each ref is a path whose last segment is the entity's index.
        source_index = int(refs[0]["ref"].split("/")[-1])
        target_index = int(refs[1]["ref"].split("/")[-1])
        if source_index not in graph or target_index not in graph:
            # add_edge would otherwise create an attribute-less node and the
            # label/category lookups below would fail.
            continue
        graph.add_edge(source_index, target_index, label=relation["relationType"])

    # Fixed seed: Streamlit reruns the script on every widget interaction, and
    # an unseeded spring layout would reshuffle the graph each time.
    pos = nx.spring_layout(graph, seed=42)

    def node_trace(nodes, visible):
        """Scatter trace for *nodes*, each colored by its own category."""
        nodes = list(nodes)
        return go.Scatter(
            x=[pos[n][0] for n in nodes],
            y=[pos[n][1] for n in nodes],
            text=[graph.nodes[n]["label"] for n in nodes],
            mode="markers+text",
            hoverinfo="text",
            marker=dict(
                color=[category_to_color[graph.nodes[n]["category"]] for n in nodes],
                size=10,
            ),
            visible=visible,
        )

    # Edges as one line trace; None breaks the line between segments.
    x_edges, y_edges = [], []
    for source, target in graph.edges:
        x_edges += [pos[source][0], pos[target][0], None]
        y_edges += [pos[source][1], pos[target][1], None]
    edge_trace = go.Scatter(
        x=x_edges,
        y=y_edges,
        line=dict(width=0.5, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    nodes_with_edges = sorted({node for edge in graph.edges for node in edge})

    return go.Figure(
        data=[
            edge_trace,
            # Start in the "All Entities" state so the two node traces do not
            # overlap before a button is clicked.
            node_trace(graph.nodes, True),
            node_trace(nodes_with_edges, False),
        ],
        layout=go.Layout(
            title=dict(
                text="Entities and Relationships in Patient Notes",
                font=dict(size=16),
            ),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            updatemenus=[
                dict(
                    type="buttons",
                    x=1.15,
                    y=1.2,
                    buttons=[
                        dict(
                            label="All Entities",
                            method="update",
                            args=[{"visible": [True, True, False]}],
                        ),
                        dict(
                            label="Entities with Relationships",
                            method="update",
                            args=[{"visible": [True, False, True]}],
                        ),
                    ],
                )
            ],
        ),
    )


def format_sdoh_entities_as_list(json_data):
    """Return a markdown bullet list of social-determinants-of-health entities.

    Args:
        json_data: The job ``results`` object; every document's entities are scanned.

    Returns:
        One ``- **CATEGORY** : 'text'`` line per entity whose upper-cased
        category is in the SDOH set below, joined by newlines. Returns an
        empty string if there are none.
    """
    relevant_categories = {
        "EMPLOYMENT",
        "LIVING_STATUS",
        "SUBSTANCEUSE",
        "SUBSTANCEUSEAMOUNT",
        "ETHNICITY",
    }
    formatted_result = [
        f"- **{entity['category'].upper()}** : '{entity['text']}' \n"
        for document in json_data["documents"]
        for entity in document["entities"]
        if entity["category"].upper() in relevant_categories
    ]
    return "\n".join(formatted_result)


def _chat(model, messages, temperature, max_tokens):
    """Send *messages* to the Azure OpenAI deployment *model*; return the reply text."""
    response = openai.ChatCompletion.create(
        engine=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )
    return response.choices[0].message.content


# Divide the page into 3 columns.
col1, col2, col3 = st.columns([2, 5, 2])

for _key in ("r", "r_annotated", "colour_to_category"):
    if _key not in st.session_state:
        st.session_state[_key] = NOT_ANALYZED

with col1:
    col1.subheader("Patient Note Input")
    st.text("Enter your text input below:")
    dax_input = st.text_area("", height=500)
    analyze_btn = st.button("Analyze")

with col2:
    col2.subheader("Text Analytics for Health Output")

    # Run the analysis first and store it, then render once from session
    # state, so an analysis is never shown next to the previous one.
    if analyze_btn:
        if not dax_input.strip():
            st.warning("Enter a patient note before clicking Analyze.")
        else:
            try:
                with st.spinner("Analyzing"):
                    job = analyze_healthcare_text(dax_input)
                results = job["tasks"]["items"][0]["results"]
            except (
                requests.RequestException,
                RuntimeError,
                TimeoutError,
                KeyError,
                IndexError,
            ) as err:
                st.error(f"Text Analytics for Health request failed: {err}")
            else:
                st.session_state.r = results
                r_annotated, category_to_color = annotate_text_with_entities(
                    dax_input, results
                )
                st.session_state.r_annotated = r_annotated
                st.session_state.colour_to_category = category_to_color

    if st.session_state.r_annotated != NOT_ANALYZED:
        with st.expander("Entity Mappings"):
            st.markdown(st.session_state.r_annotated, unsafe_allow_html=True)
        with st.expander("Show Relationships"):
            st.plotly_chart(
                create_interactive_graph_from_json(
                    st.session_state.r, st.session_state.colour_to_category
                ),
                use_container_width=True,
            )
        with st.expander("Show JSON"):
            st.json(st.session_state.r)
        with st.expander("Show SDOH"):
            st.write("Social Determinants of Health (SDOH) Entities")
            st.write(format_sdoh_entities_as_list(st.session_state.r))

with col3:
    col3.subheader("Copilot Concept")
    question = st.text_input("Ask a question to Copilot:")
    grounded = st.selectbox(
        "Would you like to ground the model?",
        ("Not Grounded", "Text Analytics for Health", "Just Text Input", "Both"),
    )
    # Model is hard-coded: replace it with your Azure OpenAI deployment name.
    model = st.selectbox("Model", ["gpt-4o"])
    ask = st.button("Ask")
    report_btn = st.button("Generate SDOH Report")

    if report_btn:
        with st.spinner("Generating Report"):
            report = _chat(
                model,
                [
                    {
                        "role": "system",
                        "content": "Your job is to write a brief list of potential SDOH "
                        "information in the data you are given. Then give recomendations "
                        "on things a doctor might consider in relation to those factors "
                        "when providing care.",
                    },
                    {"role": "user", "content": f"data: ``` {dax_input} ```"},
                ],
                temperature=0,
                max_tokens=800,
            )
        st.write(report)

    if ask:
        if grounded == "Not Grounded":
            answer = _chat(
                model,
                [
                    {
                        "role": "system",
                        "content": "You are an AI assistant that helps people find "
                        "information. Be super brief.",
                    },
                    {"role": "user", "content": question},
                ],
                temperature=0.7,
                max_tokens=200,
            )
        else:
            # Before any analysis, session state holds the placeholder; ground
            # on nothing rather than sending the placeholder string to the model.
            analysis_text = (
                "" if st.session_state.r == NOT_ANALYZED else str(st.session_state.r)
            )
            if grounded == "Text Analytics for Health":
                total_text = analysis_text
            elif grounded == "Just Text Input":
                total_text = dax_input
            else:  # "Both"
                total_text = analysis_text + dax_input
            answer = _chat(
                model,
                [
                    {
                        "role": "system",
                        "content": "You are an AI assistant that helps doctors find "
                        "information in complex structured data below. Answer the question "
                        "directly and very briefly using the data provided. \n Question: "
                        f"{question}, if you do not have information below to answer it, "
                        "state that you can't answer the question with the information "
                        "you have been given. ",
                    },
                    {"role": "user", "content": f"data: ``` {total_text} ```"},
                ],
                temperature=0,
                max_tokens=300,
            )
        st.write(answer)