File size: 16,026 Bytes
852ef53
 
 
 
 
 
fd3f788
 
 
 
 
 
 
 
 
 
 
 
d0690e4
852ef53
f1f7255
5f0bc7f
 
8f06900
332c65f
cbc93d6
 
663ae5c
fd3f788
852ef53
8f06900
852ef53
 
5f0bc7f
8f06900
852ef53
 
050a938
852ef53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332c65f
852ef53
 
 
 
 
 
 
fd3f788
852ef53
fd3f788
 
332c65f
fd3f788
 
 
 
852ef53
 
 
cbc93d6
 
852ef53
 
fd3f788
cbc93d6
 
 
 
fd3f788
 
 
 
 
 
 
 
 
 
 
 
852ef53
fd3f788
 
 
 
852ef53
fd3f788
 
 
 
852ef53
fd3f788
 
852ef53
fd3f788
 
 
 
 
 
 
 
 
 
 
 
 
663ae5c
 
fd3f788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852ef53
fd3f788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9551cbd
fd3f788
 
 
9551cbd
 
fd3f788
 
852ef53
fd3f788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852ef53
fd3f788
 
 
 
 
 
852ef53
fd3f788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852ef53
cbc93d6
 
 
8e8d035
cbc93d6
 
 
 
 
 
 
 
 
 
 
 
fd3f788
 
852ef53
fd3f788
 
852ef53
fd3f788
 
852ef53
fd3f788
 
852ef53
fd3f788
9551cbd
852ef53
9551cbd
852ef53
fd3f788
 
 
 
 
 
 
 
 
 
 
 
cbc93d6
 
 
fd3f788
 
 
 
 
 
 
 
 
 
 
 
 
cbc93d6
 
 
852ef53
fd3f788
 
 
28dc1ef
5f0bc7f
fe22626
fd3f788
cbc93d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28dc1ef
fd3f788
28dc1ef
9551cbd
 
cbc93d6
9551cbd
28dc1ef
852ef53
cbc93d6
 
 
 
 
fe22626
cbc93d6
 
 
 
 
 
fd3f788
9551cbd
fd3f788
cbc93d6
 
 
 
 
 
 
 
fd3f788
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
import openai
import streamlit as st
import os
import json
import time
import requests
from html import escape
from random import choice
import plotly.graph_objects as go
import networkx as nx
import os
import openai
from langchain.agents import create_json_agent, AgentExecutor
from langchain.agents.agent_toolkits import JsonToolkit
from langchain.chains import LLMChain
from langchain.chat_models import AzureChatOpenAI
from langchain.requests import TextRequestsWrapper
from langchain.tools.json.tool import JsonSpec
from langchain.indexes import VectorstoreIndexCreator

openai.api_type = "azure"

# Azure OpenAI endpoint and API version. Set OPENAI_API_BASE / OPENAI_API_VERSION
# in the environment to target your own resource; the values below are only the
# fallbacks used when those variables are unset. Both the env vars and the
# `openai` module globals are set because langchain's Azure clients read the env.
os.environ["OPENAI_API_BASE"] = openai.api_base = os.environ.get(
    "OPENAI_API_BASE", "https://openaihelia.openai.azure.com/"
)
os.environ["OPENAI_API_VERSION"] = openai.api_version = os.environ.get(
    "OPENAI_API_VERSION", "2024-04-01-preview"
)
# Required: raises KeyError at start-up if the key is not configured.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Must remain the first Streamlit call executed by the script.
st.set_page_config(layout="wide")

#Azure AI Services
def analyze_healthcare_text(text, poll_interval=1.0, timeout=300.0):
    """Run Azure Text Analytics for Health on *text* and wait for the job result.

    Submits an asynchronous ``Healthcare`` analyze-text job, then polls the job
    status URL until the job reaches a terminal state.

    Args:
        text: The clinical note to analyze (sent as a single English document).
        poll_interval: Seconds to sleep between status polls.
        timeout: Maximum total seconds to wait for the job to finish.

    Returns:
        The parsed JSON body of the final (``succeeded``) job-status response.

    Raises:
        requests.HTTPError: If the submit or a status request returns an HTTP error.
        RuntimeError: If no ``operation-location`` header is returned, or the job
            ends in a non-``succeeded`` terminal state.
        TimeoutError: If the job has not finished within *timeout* seconds.
    """
    # Hardcoded: replace the base URL with your own resource but keep the path,
    # i.e. only swap e.g. https://ta4h-endpoint.cognitiveservices.azure.com
    base_url = "https://azaiserviceshm.cognitiveservices.azure.com/language/analyze-text/jobs"
    api_version = "2023-04-01"
    # Per-HTTP-call timeout (seconds) so a hung connection cannot block the app forever.
    request_timeout = 30
    headers = {
        "Content-Type": "application/json",
        "Ocp-Apim-Subscription-Key": os.environ["TA4HAPIKEY"],
    }

    data = {
        "tasks": [
            {
                "kind": "Healthcare",
                # Request code-point offsets so they index Python strings directly;
                # the service default counts grapheme clusters instead.
                "parameters": {"stringIndexType": "UnicodeCodePoint"},
            }
        ],
        "analysisInput": {
            "documents": [
                {
                    "id": "documentId",
                    "text": text,
                    "language": "en",
                }
            ]
        },
    }

    response = requests.post(
        f"{base_url}?api-version={api_version}",
        headers=headers,
        json=data,
        timeout=request_timeout,
    )
    response.raise_for_status()

    operation_location = response.headers.get("operation-location")
    if not operation_location:
        raise RuntimeError("Text Analytics for Health did not return an operation-location header")
    # The job id is the last path segment, minus any query string.
    job_id = operation_location.split("/")[-1].split("?")[0]

    deadline = time.monotonic() + timeout
    while True:
        result_response = requests.get(
            f"{base_url}/{job_id}?api-version={api_version}",
            headers=headers,
            timeout=request_timeout,
        )
        result_response.raise_for_status()
        result = result_response.json()
        status = result.get("status")
        if status == "succeeded":
            return result
        # Terminal non-success states: without this check the loop never ends.
        if status in ("failed", "cancelled", "partiallyCompleted"):
            raise RuntimeError(
                f"Text Analytics for Health job {job_id} ended with status {status!r}: "
                f"{result.get('errors')}"
            )
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Text Analytics for Health job {job_id} did not finish within {timeout} seconds"
            )
        time.sleep(poll_interval)

def annotate_text_with_entities(original_text, entities_data):
    """Render *original_text* as HTML with every recognised entity highlighted.

    Each entity becomes a coloured pill showing its text and category, plus a
    dropdown of linked codes when the entity has ``links``. The plain text
    between entities is HTML-escaped, and a colour key is prepended.

    Args:
        original_text: The note that was analysed. Entity ``offset``/``length``
            values are used as indices into this string.
        entities_data: A results object shaped like
            ``{"documents": [{"entities": [...]}]}``. Only the first document is used.

    Returns:
        A ``(html, category_to_color)`` tuple: the annotated HTML string and the
        mapping from entity category to its ``#rrggbb`` colour.
    """
    # NOTE(review): this writes the full analysis (patient data) to the working
    # directory on every call, and concurrent sessions overwrite it. It looks like
    # a debugging leftover; confirm nothing reads this file before removing it.
    with open('entities_data.json', 'w') as f:
        json.dump(entities_data, f)

    # Color palette for different categories
    PALETTE = [
        "#ff4b4b",
        "#ffa421",
        "#ffe312",
        "#21c354",
        "#00d4b1",
        "#00c0f2",
        "#1c83e1",
        "#803df5",
        "#808495",
    ]

    # Hex alpha suffixes, alternated by entity offset parity, presumably so that
    # neighbouring pills of one category do not look identical.
    OPACITIES = ["33", "66"]

    entities = entities_data['documents'][0]['entities']

    # Categories in first-appearance order. A set's iteration order for strings
    # varies between interpreter processes, which would reshuffle the colours.
    unique_categories = list(dict.fromkeys(entity['category'] for entity in entities))

    category_to_color = {
        category: PALETTE[i % len(PALETTE)] for i, category in enumerate(unique_categories)
    }

    # Past the end of the palette, give each extra category a random colour
    # instead of reusing a palette entry.
    if len(unique_categories) > len(PALETTE):
        for category in unique_categories[len(PALETTE):]:
            category_to_color[category] = '#' + ''.join(choice('0123456789ABCDEF') for _ in range(6))

    def create_entity_html(entity, entity_id):
        """Return the HTML pill for one entity; every service-provided value is escaped."""
        color = category_to_color[entity["category"]] + OPACITIES[entity["offset"] % len(OPACITIES)]
        category_label = escape(entity["category"])

        entity_html = f'<span id="entity-{entity_id}"><span style="display: inline-flex; flex-direction: row; align-items: center; background: {color}; border-radius: 0.5rem; padding: 0.25rem 0.5rem; overflow: hidden; line-height: 1;">{escape(entity["text"])}'

        if entity.get("links"):
            # One <option> per linked ontology code, with the category shown under the dropdown.
            options = "".join(
                f'<option value="{escape(str(link["id"]))}">{escape(str(link["dataSource"]))} Code {escape(str(link["id"]))}</option>'
                for link in entity["links"]
            )
            entity_html += f'''
            <span style="border-left: 1px solid; opacity: 0.1; margin-left: 0.5rem; align-self: stretch;"></span>
            <span style="margin-left: 0.5rem; display: flex; flex-direction: column; align-items: flex-start;">
                <select style="font-size: 0.75rem; opacity: 0.5;">
                    {options}
                </select>
                <label style="font-size: 0.6rem; margin-top: 0.25rem;">{category_label}</label>
            </span>
            '''
        else:
            # No links: just show the category label.
            entity_html += f'<span style="border-left: 1px solid; opacity: 0.1; margin-left: 0.5rem; align-self: stretch;"></span><span style="margin-left: 0.5rem; font-size: 0.75rem; opacity: 0.5;">{category_label}</span>'

        # Close the inner pill span and the outer id span.
        entity_html += '</span></span>'
        return entity_html

    # Walk the entities in offset order and escape the plain text between them,
    # so characters such as "<" or "&" in the note cannot break or inject markup.
    # Entity ids keep the entities' original list positions.
    pieces = []
    cursor = 0
    for entity_id, entity in sorted(enumerate(entities), key=lambda pair: pair[1]['offset']):
        start = entity['offset']
        end = start + entity['length']
        if start < cursor:
            # Overlaps an entity that was already emitted; splicing both would corrupt the HTML.
            continue
        pieces.append(escape(original_text[cursor:start]))
        pieces.append(create_entity_html(entity, entity_id))
        cursor = end
    pieces.append(escape(original_text[cursor:]))
    annotated_text = "".join(pieces)

    color_key_section = "<strong>Color Key:</strong><br>" + "".join(
        f'<span style="display: inline-block; background: {color}; width: 1em; height: 1em; margin-right: 0.5em; vertical-align: middle;"></span>{escape(category)}<br>'
        for category, color in category_to_color.items()
    )

    return color_key_section + annotated_text, category_to_color

def create_interactive_graph_from_json(json_data, category_to_color):
    """Build a Plotly network figure of the entities and relations in an analysis result.

    Nodes are the entities of the first document, coloured by category. Edges
    come from ``relations``: the first two members of each relation are treated
    as source and target. Two buttons toggle between showing every entity and
    showing only entities that take part in a relation.

    Args:
        json_data: A results object shaped like
            ``{"documents": [{"entities": [...], "relations": [...]}]}``.
        category_to_color: Mapping from entity category to a colour string. It
            must contain every category present in the entities.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    document = json_data['documents'][0]
    entities = document['entities']
    relations = document.get('relations', [])

    graph = nx.DiGraph()
    for i, entity in enumerate(entities):
        graph.add_node(i, label=entity['text'], category=entity['category'])

    for relation in relations:
        members = relation.get('entities', [])
        if len(members) < 2:
            continue  # nothing to connect
        # Refs look like ".../entities/<index>"; the last path segment is the entity index.
        source_index = int(members[0]['ref'].split('/')[-1])
        target_index = int(members[1]['ref'].split('/')[-1])
        # An unknown index would make add_edge create a node with no label/category.
        if source_index not in graph or target_index not in graph:
            continue
        graph.add_edge(source_index, target_index, label=relation['relationType'])

    pos = nx.spring_layout(graph)

    x_edges = []
    y_edges = []
    for source, target in graph.edges:
        # None breaks the line so each edge is drawn as a separate segment.
        x_edges += [pos[source][0], pos[target][0], None]
        y_edges += [pos[source][1], pos[target][1], None]
    edge_trace = go.Scatter(x=x_edges, y=y_edges, line=dict(width=0.5, color='#888'),
                            hoverinfo='none', mode='lines')

    def node_trace(nodes, visible):
        """Scatter trace for *nodes*, with each marker coloured by that node's own category."""
        return go.Scatter(
            x=[pos[n][0] for n in nodes],
            y=[pos[n][1] for n in nodes],
            text=[graph.nodes[n]['label'] for n in nodes],
            mode='markers+text',
            hoverinfo='text',
            marker=dict(color=[category_to_color[graph.nodes[n]['category']] for n in nodes], size=10),
            visible=visible,
        )

    all_nodes = list(graph.nodes)
    linked_nodes = sorted({node for edge in graph.edges for node in edge})

    # Initial visibility matches the "All Entities" button state.
    fig = go.Figure(data=[edge_trace, node_trace(all_nodes, True), node_trace(linked_nodes, False)],
                    layout=go.Layout(title=dict(text='Entities and Relationships in Patient Notes',
                                                font=dict(size=16)),
                                     showlegend=False,
                                     hovermode='closest',
                                     margin=dict(b=20, l=5, r=5, t=40),
                                     xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                                     yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                                     updatemenus=[dict(type="buttons",
                                                       x=1.15,
                                                       y=1.2,
                                                       buttons=[dict(label="All Entities",
                                                                     method="update",
                                                                     args=[{"visible": [True, True, False]}]),
                                                                dict(label="Entities with Relationships",
                                                                     method="update",
                                                                     args=[{"visible": [True, False, True]}])])]))

    return fig


def format_sdoh_entities_as_list(json_data):
    """Return a Markdown bullet list of the social-determinants-of-health entities.

    Args:
        json_data: A results object with a ``documents`` list, each item holding ``entities``.

    Returns:
        One ``- **CATEGORY** : 'text' `` bullet (with trailing newline) per matching
        entity, joined by newlines. Returns an empty string when nothing matches.
    """
    # Compared with underscores stripped: the old list mixed 'SUBSTANCEUSE' with
    # 'LIVING_STATUS', and a PascalCase category such as "LivingStatus" upper-cases
    # to "LIVINGSTATUS", so the underscored spelling could never match.
    relevant_categories = {'EMPLOYMENT', 'LIVINGSTATUS', 'SUBSTANCEUSE', 'SUBSTANCEUSEAMOUNT', 'ETHNICITY'}
    formatted_result = []

    for document in json_data['documents']:
        for entity in document['entities']:
            category = entity['category'].upper()
            if category.replace('_', '') in relevant_categories:
                formatted_result.append(f"- **{category}** : '{entity['text']}' \n")
    # No print() here: the old debug print copied patient-derived text into server logs.
    return '\n'.join(formatted_result)


# Page layout: note input | analysis output | copilot, at 2:5:2 relative widths.
col1, col2, col3 = st.columns([2, 5, 2])

# Seed each per-session slot with the placeholder 'value', meaning "no analysis
# has run yet". Later code compares against it before rendering results.
for _state_key in ('r', 'r_annotated', 'colour_to_category'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = 'value'

with col1:
    col1.subheader("Patient Note Input")

    st.text("Enter your text input below:")

    # An empty label triggers a Streamlit accessibility warning and leaves screen
    # readers with no name for the field. Give it a real label and hide it, since
    # the text above already serves as the visible prompt.
    dax_input = st.text_area("Patient note", height=500, label_visibility="collapsed")
    analyze_btn = st.button("Analyze")

with col2:
    col2.subheader("Text Analytics for Health Output")

    # Run a fresh analysis *before* rendering, so results are drawn exactly once
    # per rerun. Rendering the stored results first and then the new ones used to
    # duplicate every expander when Analyze was clicked a second time.
    if analyze_btn:
        if not dax_input.strip():
            st.warning("Please enter a patient note to analyze.")
        else:
            try:
                with st.spinner("Analyzing patient note"):
                    results = analyze_healthcare_text(dax_input)["tasks"]["items"][0]["results"]
                    r_annotated, category_to_color = annotate_text_with_entities(dax_input, results)
            except Exception as err:  # UI boundary: show the failure instead of crashing the page
                st.error(f"Text Analytics for Health analysis failed: {err}")
            else:
                # Commit all three together so session state is never half-updated.
                st.session_state.r = results
                st.session_state.r_annotated = r_annotated
                st.session_state.colour_to_category = category_to_color

    # 'value' is the placeholder that means no analysis has been run yet.
    if st.session_state.r_annotated != 'value':
        with st.expander("Entity Mappings"):
            st.markdown(st.session_state.r_annotated, unsafe_allow_html=True)
        with st.expander("Show Relationships"):
            st.plotly_chart(create_interactive_graph_from_json(st.session_state.r, st.session_state.colour_to_category), use_container_width=True)
        with st.expander("Show JSON"):
            st.json(st.session_state.r)
        with st.expander("Show SDOH"):
            st.write("Social Determinants of Health (SDOH) Entities")
            st.write(format_sdoh_entities_as_list(st.session_state.r))

with col3:
    col3.subheader("Copilot Concept")
    question = st.text_input("Ask a question to Copilot:")
    grounded = st.selectbox('Would you like to ground the model?', ('Not Grounded', 'Text Analytics for Health', 'Just Text Input', 'Both'))
    # model is hardcoded - replace with whatever your deployment name is below
    model = st.selectbox("Model", ["gpt-4o"])
    ask = st.button("Ask")
    report_btn = st.button("Generate SDOH Report")

    def _ask_model(messages, temperature, max_tokens):
        """Send *messages* to the selected Azure OpenAI deployment and return the reply text.

        Every prompt in this column uses the same sampling settings apart from
        temperature and max_tokens.
        """
        response = openai.ChatCompletion.create(
            engine=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None)
        return response.choices[0].message.content

    if report_btn:
        if not dax_input.strip():
            st.warning("Please enter a patient note before generating an SDOH report.")
        else:
            with st.spinner("Generating Report"):
                st.write(_ask_model(
                    [{"role": "system", "content": "Your job is to write a brief list of potential SDOH information in the data you are given. Then give recomendations on things a doctor might consider in relation to those factors when providing care."},
                     {"role": "user", "content": f"data: ``` {dax_input} ```"}],
                    temperature=0,
                    max_tokens=800))

    if ask:
        # st.session_state.r still holds the placeholder 'value' until Analyze has
        # run; grounding on it would send that literal string to the model.
        needs_analysis = grounded in ('Text Analytics for Health', 'Both')
        if not question.strip():
            st.warning("Please enter a question for Copilot.")
        elif needs_analysis and st.session_state.r == 'value':
            st.warning("Click 'Analyze' first so the Text Analytics for Health output can be used for grounding.")
        elif grounded == 'Not Grounded':
            with st.spinner("Thinking"):
                st.write(_ask_model(
                    [{"role": "system", "content": "You are an AI assistant that helps people find information. Be super brief."},
                     {"role": "user", "content": question}],
                    temperature=0.7,
                    max_tokens=200))
        else:
            if grounded == 'Text Analytics for Health':
                total_text = str(st.session_state.r)
            elif grounded == 'Just Text Input':
                total_text = dax_input
            else:  # 'Both': keep the analysis JSON and the raw note visibly separate
                total_text = str(st.session_state.r) + "\n\n" + dax_input

            with st.spinner("Thinking"):
                st.write(_ask_model(
                    [{"role": "system", "content": f"You are an AI assistant that helps doctors find information in complex structured data below. Answer the question directly and very briefly using the data provided. \n Question: {question}, if you do not have information below to answer it, state that you can't answer the question with the information you have been given. "},
                     {"role": "user", "content": f"data: ``` {total_text} ```"}],
                    temperature=0,
                    max_tokens=300))