Spaces:
Sleeping
Sleeping
| import openai | |
| import streamlit as st | |
| import os | |
| import json | |
| import time | |
| import requests | |
| from html import escape | |
| from random import choice | |
| import plotly.graph_objects as go | |
| import networkx as nx | |
| import os | |
| import openai | |
| from langchain.agents import create_json_agent, AgentExecutor | |
| from langchain.agents.agent_toolkits import JsonToolkit | |
| from langchain.chains import LLMChain | |
| from langchain.chat_models import AzureChatOpenAI | |
| from langchain.requests import TextRequestsWrapper | |
| from langchain.tools.json.tool import JsonSpec | |
| from langchain.indexes import VectorstoreIndexCreator | |
# Point the OpenAI SDK at an Azure OpenAI resource.
openai.api_type = "azure"

# The endpoint is hardcoded here; replace it with your own Azure OpenAI
# endpoint. (Previously used: https://japan-sean-aoai.openai.azure.com/)
# The value is published both as an environment variable and on the SDK.
os.environ["OPENAI_API_BASE"] = "https://openaihelia.openai.azure.com/"
openai.api_base = os.environ["OPENAI_API_BASE"]

# API version, mirrored the same way. (Previously used: 2023-04-01)
os.environ["OPENAI_API_VERSION"] = "2024-04-01-preview"
openai.api_version = os.environ["OPENAI_API_VERSION"]

# The key is read from the environment; a missing OPENAI_API_KEY raises KeyError.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Use the wide page layout for the app.
st.set_page_config(layout="wide")
| #Azure AI Services | |
def analyze_healthcare_text(text):
    """Run a Text Analytics for health job on ``text`` and return its result.

    Submits an asynchronous "Healthcare" analysis job, then polls the job
    once per second until it finishes.

    Args:
        text: The English document to analyse.

    Returns:
        The decoded JSON body of the job once its status is ``succeeded``.

    Raises:
        KeyError: If the ``TA4HAPIKEY`` environment variable is not set, or
            the job payload carries no ``status`` field.
        requests.HTTPError: If the service answers either request with an
            HTTP error status.
        RuntimeError: If the submit response has no ``operation-location``
            header, or the job ends in a state other than ``succeeded``.
    """
    # The base URL is hardcoded: replace the host with your own Language
    # resource endpoint but keep the path, e.g.
    # https://ta4h-endpoint.cognitiveservices.azure.com/language/analyze-text/jobs
    base_url = "https://azaiserviceshm.cognitiveservices.azure.com/language/analyze-text/jobs"
    api_version = "2023-04-01"  # previously 2022-10-01-preview
    request_timeout = 30  # seconds; stops a stalled connection hanging the app

    headers = {
        "Content-Type": "application/json",
        "Ocp-Apim-Subscription-Key": os.environ["TA4HAPIKEY"],
    }

    # Payload for the initial POST request: one Healthcare task, one document.
    data = {
        "tasks": [{"kind": "Healthcare"}],
        "analysisInput": {
            "documents": [
                {
                    "id": "documentId",
                    "text": text,
                    "language": "en",
                }
            ]
        },
    }

    # Submit the job and fail loudly on a rejected request.
    response = requests.post(
        f"{base_url}?api-version={api_version}",
        headers=headers,
        json=data,
        timeout=request_timeout,
    )
    response.raise_for_status()

    # The job id is the last path segment of the operation-location header.
    operation_location = response.headers.get('operation-location')
    if not operation_location:
        raise RuntimeError(
            "Text Analytics for health did not return an operation-location header"
        )
    job_id = operation_location.split('/')[-1].split('?')[0]

    # Job states that will never turn into 'succeeded'; polling on them would
    # loop forever.
    # NOTE(review): 'partiallyCompleted' is treated as a failure because the
    # caller needs the single task's results - confirm this is the desired
    # handling.
    failed_states = ('failed', 'cancelled', 'partiallyCompleted')

    # Poll the job until it reaches a final state.
    while True:
        result_response = requests.get(
            f"{base_url}/{job_id}?api-version={api_version}",
            headers=headers,
            timeout=request_timeout,
        )
        result_response.raise_for_status()
        result = result_response.json()

        status = result['status']
        if status == 'succeeded':
            return result
        if status in failed_states:
            raise RuntimeError(
                f"Text Analytics for health job {job_id} ended with status "
                f"'{status}': {result.get('errors')}"
            )
        time.sleep(1)
def annotate_text_with_entities(original_text, entities_data):
    """Render ``original_text`` as HTML with each recognised entity highlighted.

    Each entity span is replaced by a coloured inline badge showing the
    entity text, its category and, when the entity has links, a drop-down of
    the linked codes. A colour key is prepended to the result.

    Everything taken from the input (note text, entity text, categories and
    link fields) is HTML-escaped, because the caller renders the result with
    ``unsafe_allow_html=True``.

    Args:
        original_text: The text that was analysed.
        entities_data: Analysis results; entities are read from
            ``entities_data['documents'][0]['entities']``. Each entity needs
            ``text``, ``category``, ``offset`` and ``length`` and may carry
            ``links`` (items with ``id`` and ``dataSource``).
            NOTE(review): offsets are used as Python string indices - confirm
            the service's string index type matches.

    Returns:
        A tuple ``(html, category_to_color)`` where ``category_to_color``
        maps each category (unescaped) to its ``#rrggbb`` colour.
    """
    # Save the JSON to a file (kept from the original implementation).
    # NOTE(review): this writes the analysis output to the working directory
    # on every call - confirm it is still wanted.
    with open('entities_data.json', 'w') as f:
        json.dump(entities_data, f)

    # Colour palette for different categories.
    PALETTE = [
        "#ff4b4b",
        "#ffa421",
        "#ffe312",
        "#21c354",
        "#00d4b1",
        "#00c0f2",
        "#1c83e1",
        "#803df5",
        "#808495",
    ]
    # Two-digit hex suffixes appended to a colour (chosen by offset parity).
    OPACITIES = ["33", "66"]
    SEPARATOR = (
        '<span style="border-left: 1px solid; opacity: 0.1; '
        'margin-left: 0.5rem; align-self: stretch;"></span>'
    )

    entities = entities_data['documents'][0]['entities']

    # dict.fromkeys de-duplicates while keeping first-seen order, so a
    # category gets the same colour on every run (a set does not guarantee
    # that).
    unique_categories = list(dict.fromkeys(entity['category'] for entity in entities))

    # Palette colours first; once they run out, the remaining categories get
    # random colours.
    category_to_color = {}
    for i, category in enumerate(unique_categories):
        if i < len(PALETTE):
            category_to_color[category] = PALETTE[i]
        else:
            category_to_color[category] = '#' + ''.join(
                choice('0123456789ABCDEF') for _ in range(6)
            )

    def create_entity_html(entity, entity_id):
        """Build the inline badge for one entity."""
        color = category_to_color[entity["category"]] + OPACITIES[entity["offset"] % len(OPACITIES)]
        category_label = escape(str(entity["category"]))
        parts = [
            f'<span id="entity-{entity_id}"><span style="display: inline-flex; flex-direction: row; align-items: center; background: {color}; border-radius: 0.5rem; padding: 0.25rem 0.5rem; overflow: hidden; line-height: 1;">{escape(entity["text"])}',
            SEPARATOR,
        ]
        if entity.get("links"):
            # If there are links, create a drop-down menu with the links.
            options = "".join(
                f'<option value="{escape(str(link["id"]))}">'
                f'{escape(str(link["dataSource"]))} Code {escape(str(link["id"]))}</option>'
                for link in entity["links"]
            )
            parts.append(
                '<span style="margin-left: 0.5rem; display: flex; flex-direction: column; align-items: flex-start;">'
                f'<select style="font-size: 0.75rem; opacity: 0.5;">{options}</select>'
                f'<label style="font-size: 0.6rem; margin-top: 0.25rem;">{category_label}</label>'
                '</span>'
            )
        else:
            # If there are no links, just display the category label.
            parts.append(
                f'<span style="margin-left: 0.5rem; font-size: 0.75rem; opacity: 0.5;">{category_label}</span>'
            )
        # Close the two wrapping span elements.
        parts.append('</span></span>')
        return ''.join(parts)

    # Walk the entities left to right, escaping the plain text between them.
    # enumerate() runs before sorting so each entity keeps its original index
    # as its element id.
    pieces = []
    cursor = 0
    for entity_id, entity in sorted(enumerate(entities), key=lambda pair: pair[1]['offset']):
        start = entity['offset']
        end = start + entity['length']
        if start < cursor:
            # Overlaps an entity that was already rendered; splicing it in
            # would corrupt the markup, so it is left out of the inline view.
            continue
        pieces.append(escape(original_text[cursor:start], quote=False))
        pieces.append(create_entity_html(entity, entity_id))
        cursor = end
    pieces.append(escape(original_text[cursor:], quote=False))

    # Create the colour key section.
    color_key_section = "<strong>Color Key:</strong><br>" + ''.join(
        f'<span style="display: inline-block; background: {color}; width: 1em; height: 1em; margin-right: 0.5em; vertical-align: middle;"></span>{escape(str(category))}<br>'
        for category, color in category_to_color.items()
    )

    return color_key_section + ''.join(pieces), category_to_color
def create_interactive_graph_from_json(json_data, category_to_color):
    """Build a Plotly figure of the entities and the relations between them.

    Every entity of the first document becomes a node, coloured by category.
    Every relation becomes a directed edge between the entities named by its
    first two ``ref`` values. Two buttons switch between showing all nodes
    and showing only nodes that take part in at least one relation.

    Args:
        json_data: Analysis results; reads
            ``json_data['documents'][0]['entities']`` and the optional
            ``relations`` list of the same document.
        category_to_color: Mapping from entity category to marker colour.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    entities = json_data['documents'][0]['entities']
    relations = json_data['documents'][0].get('relations', [])

    # Create a new directed graph with one node per entity, keyed by index.
    graph = nx.DiGraph()
    for i, entity in enumerate(entities):
        graph.add_node(i, label=entity['text'], category=entity['category'])

    # The entity index is the last path segment of each relation 'ref'.
    for relation in relations:
        source_index = int(relation['entities'][0]['ref'].split('/')[-1])
        target_index = int(relation['entities'][1]['ref'].split('/')[-1])
        graph.add_edge(source_index, target_index, label=relation['relationType'])

    # Get positions of the nodes using spring layout.
    pos = nx.spring_layout(graph)

    def node_trace_for(node_ids):
        """Build a marker+text trace for the given node ids.

        Positions, labels and colours are all derived from the same id
        sequence so they always line up.
        """
        return go.Scatter(
            x=[pos[i][0] for i in node_ids],
            y=[pos[i][1] for i in node_ids],
            text=[graph.nodes[i]['label'] for i in node_ids],
            mode='markers+text',
            hoverinfo='text',
            marker=dict(
                color=[category_to_color[graph.nodes[i]['category']] for i in node_ids],
                size=10,
            ),
        )

    # Edge coordinates; None separates one line segment from the next.
    x_edges = []
    y_edges = []
    for source, target in graph.edges:
        x_edges += [pos[source][0], pos[target][0], None]
        y_edges += [pos[source][1], pos[target][1], None]
    edge_trace = go.Scatter(x=x_edges, y=y_edges, line=dict(width=0.5, color='#888'),
                            hoverinfo='none', mode='lines')

    # Trace with all nodes.
    node_trace_all = node_trace_for(list(graph.nodes))

    # Trace with only the nodes that have at least one edge. Its colours are
    # computed for this subset; reusing the all-nodes colour list would pair
    # nodes with the wrong colours.
    nodes_with_edges = sorted({node for edge in graph.edges for node in edge})
    node_trace_with_edges = node_trace_for(nodes_with_edges)
    # Start in the "All Entities" state so the two node traces are not drawn
    # on top of each other before a button is pressed.
    node_trace_with_edges.visible = False

    # Create figure. The title font is set through title.font rather than the
    # deprecated 'titlefont_size' layout shortcut.
    fig = go.Figure(
        data=[edge_trace, node_trace_all, node_trace_with_edges],
        layout=go.Layout(
            title=dict(text='Entities and Relationships in Patient Notes', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            updatemenus=[dict(
                type="buttons",
                x=1.15,
                y=1.2,
                buttons=[
                    # Visibility order matches the trace order above:
                    # [edges, all nodes, nodes with edges].
                    dict(label="All Entities",
                         method="update",
                         args=[{"visible": [True, True, False]}]),
                    dict(label="Entities with Relationships",
                         method="update",
                         args=[{"visible": [True, False, True]}]),
                ],
            )],
        ),
    )
    return fig
def format_sdoh_entities_as_list(json_data):
    """Format social-determinants-of-health entities as a Markdown bullet list.

    Args:
        json_data: Analysis results with a ``documents`` list; each document
            has an ``entities`` list whose items carry ``category`` and
            ``text``.

    Returns:
        One ``- **CATEGORY** : 'text'`` line per matching entity, joined with
        newlines. Returns an empty string when nothing matches.
    """
    # Categories are compared upper-cased with underscores removed, so both
    # "LivingStatus" and "LIVING_STATUS" match. The previous literal
    # 'LIVING_STATUS' could never equal "LivingStatus".upper().
    relevant_categories = frozenset({
        'EMPLOYMENT',
        'LIVINGSTATUS',
        'SUBSTANCEUSE',
        'SUBSTANCEUSEAMOUNT',
        'ETHNICITY',
    })

    formatted_result = []
    for document in json_data['documents']:
        for entity in document['entities']:
            category = entity['category'].upper()
            if category.replace('_', '') in relevant_categories:
                formatted_result.append(f"- **{category}** : '{entity['text']}' \n")

    return '\n'.join(formatted_result)
# Split the page into three columns with relative widths 2 : 5 : 2.
col1, col2, col3 = st.columns([2, 5, 2])

# Seed the session-state slots used by the app. The string 'value' is the
# placeholder meaning "nothing stored yet".
for state_key in ('r', 'r_annotated', 'colour_to_category'):
    if state_key not in st.session_state:
        st.session_state[state_key] = 'value'

# Left column: the free-text note and the button that starts the analysis.
with col1:
    col1.subheader("Patient Note Input")
    st.text("Enter your text input below:")
    dax_input = st.text_area("", height=500)
    analyze_btn = st.button("Analyze")
# Middle column: run the analysis on demand, then show the stored results.
with col2:
    col2.subheader("Text Analytics for Health Output")

    # Run the analysis first so the results below are rendered exactly once
    # and always reflect the latest run. Rendering the stored results before
    # re-analysing showed the stale output followed by a second copy.
    if analyze_btn:
        if dax_input.strip():
            st.session_state.r = analyze_healthcare_text(dax_input)["tasks"]["items"][0]["results"]
            r_annotated, category_to_color = annotate_text_with_entities(dax_input, st.session_state.r)
            st.session_state.r_annotated = r_annotated
            st.session_state.colour_to_category = category_to_color
        else:
            # Do not send an empty document to the service.
            st.warning("Please enter a patient note before running the analysis.")

    # 'value' is the placeholder stored before any analysis has run.
    if st.session_state.r_annotated != 'value':
        with st.expander("Entity Mappings"):
            st.markdown(st.session_state.r_annotated, unsafe_allow_html=True)
        with st.expander("Show Relationships"):
            st.plotly_chart(
                create_interactive_graph_from_json(st.session_state.r, st.session_state.colour_to_category),
                use_container_width=True,
            )
        with st.expander("Show JSON"):
            st.json(st.session_state.r)
        with st.expander("Show SDOH"):
            st.write("Social Determinants of Health (SDOH) Entities")
            st.write(format_sdoh_entities_as_list(st.session_state.r))
# Right column: free-form questions to the chat model, optionally grounded on
# the analysis output and/or the raw note, plus an SDOH report generator.
with col3:
    col3.subheader("Copilot Concept")
    question = st.text_input("Ask a question to Copilot:")
    grounded = st.selectbox('Would you like to ground the model?', ('Not Grounded', 'Text Analytics for Health', 'Just Text Input', 'Both'))
    # The model is hardcoded - replace with whatever your deployment name is.
    model = st.selectbox("Model", ["gpt-4o"])
    ask = st.button("Ask")
    report_btn = st.button("Generate SDOH Report")

    def _chat_completion(system_prompt, user_prompt, temperature, max_tokens):
        """Send one system + user message pair to the selected deployment.

        Returns the completion response object. The sampling settings that
        are the same for every call in this column are fixed here.
        """
        return openai.ChatCompletion.create(
            engine=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
        )

    if report_btn:
        with st.spinner("Generating Report"):
            response = _chat_completion(
                "Your job is to write a brief list of potential SDOH information in the data you are given. Then give recomendations on things a doctor might consider in relation to those factors when providing care.",
                f"data: ``` {dax_input} ```",
                temperature=0,
                max_tokens=800,
            )
        st.write(response.choices[0].message.content)

    if ask:
        needs_analysis = grounded in ('Text Analytics for Health', 'Both')
        if not question.strip():
            st.warning("Please enter a question first.")
        elif needs_analysis and st.session_state.r == 'value':
            # 'value' is the placeholder stored before any analysis has run;
            # grounding on it would send the literal text 'value' as data.
            st.warning("Run the analysis first to ground the answer on Text Analytics for Health output.")
        elif grounded == 'Not Grounded':
            response = _chat_completion(
                "You are an AI assistant that helps people find information. Be super brief.",
                question,
                temperature=0.7,
                max_tokens=200,
            )
            st.write(response.choices[0].message.content)
        else:
            # Pick the grounding data for the chosen mode.
            if grounded == 'Text Analytics for Health':
                total_text = str(st.session_state.r)
            elif grounded == 'Just Text Input':
                total_text = dax_input
            else:  # 'Both'
                total_text = str(st.session_state.r) + dax_input
            response = _chat_completion(
                f"You are an AI assistant that helps doctors find information in complex structured data below. Answer the question directly and very briefly using the data provided. \n Question: {question}, if you do not have information below to answer it, state that you can't answer the question with the information you have been given. ",
                f"data: ``` {total_text} ```",
                temperature=0,
                max_tokens=300,
            )
            st.write(response.choices[0].message.content)