# RxAware_M4T1 / app.py
# Last change: commit c9331c5 (verified) by pxf4fp - "Update app.py (#7)"
import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import tensorflow as tf
import transformers
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
import matplotlib.pyplot as plt
import sys
import csv
import io
import base64
# Increase CSV field size limit
def raise_csv_field_size_limit(limit=sys.maxsize):
    """Set csv.field_size_limit to *limit*, or the largest smaller value the platform accepts.

    The csv module stores the limit in a C ``long``. Where that type is only
    32 bits wide (e.g. Windows), passing ``sys.maxsize`` raises OverflowError
    at import time. The value is therefore divided by 10 until it fits.

    Args:
        limit: the desired maximum field size.

    Returns:
        The limit that was actually applied.
    """
    while True:
        try:
            csv.field_size_limit(limit)
            return limit
        except OverflowError:
            limit //= 10


raise_csv_field_size_limit()
# Pick the compute device: the first CUDA GPU when available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Sequence-classification model whose two outputs adr_predict reports as
# "Non-severe Reaction" (index 0) and "Severe Reaction" (index 1).
_ADR_MODEL_ID = "TyHamil/ADRv2025"
tokenizer = AutoTokenizer.from_pretrained(_ADR_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_ADR_MODEL_ID)
model = model.to(device)

# Text-classification pipeline over the same model/tokenizer; top_k=None asks
# it for the score of every label.
# NOTE(review): `pred` is not referenced elsewhere in the visible file.
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=device,
)
# SHAP model callback
def predict_prob(texts):
    """Return class probabilities for a batch of texts; used as the SHAP model.

    SHAP's text masker passes the model a numpy array of (partially masked)
    strings. A Hugging Face tokenizer only accepts a ``str`` or a list/tuple
    of ``str`` and rejects numpy arrays. The input is therefore converted to
    a plain ``list[str]`` before tokenising.

    Args:
        texts: a single string, or any iterable of strings (list, tuple,
            numpy array).

    Returns:
        numpy array of shape (number of texts, number of model outputs)
        holding softmax probabilities.
    """
    if isinstance(texts, str):
        batch = [texts]
    else:
        batch = [str(t) for t in texts]
    encoded = tokenizer(batch, return_tensors='pt', padding=True, truncation=True).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**encoded)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    return probs.cpu().numpy()
# SHAP explainer over predict_prob. Passing the tokenizer lets SHAP build a
# text masker that hides tokens while computing attributions.
explainer = shap.Explainer(predict_prob, tokenizer)

# NER pipeline: biomedical entity tagger. aggregation_strategy="simple" merges
# sub-word pieces into whole entities; adr_predict reads each entity's
# 'entity_group', 'start' and 'end'.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
# Run NER on the same device as the classification model, so that a GPU is
# used when one is available.
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=device,
)
# SHAP Plotting Function
def generate_shap_plot(shap_values, output_index=1):
    """Render a SHAP text explanation as an HTML fragment for gr.HTML.

    ``shap.plots.text`` produces HTML rather than a matplotlib figure. It
    takes a ``display`` flag (there is no ``show`` argument). With
    ``display=False`` it returns the HTML string instead of displaying it.

    Args:
        shap_values: explanation for a batch of texts, as returned by
            ``explainer([text])``. Only the first text is rendered.
        output_index: which model output to explain when the explanation
            carries one column per output. Defaults to 1, the output that
            adr_predict reports as "Severe Reaction".

    Returns:
        HTML string with the tokens coloured by their SHAP contribution.
    """
    explanation = shap_values[0]
    # A multi-output explanation is (tokens, outputs). Select one output so
    # the text plot renders a single highlighted sentence.
    if len(explanation.shape) == 2:
        explanation = explanation[:, output_index]
    return shap.plots.text(explanation, display=False)
# Highlight colour for each NER entity group rendered by adr_predict; groups
# not listed fall back to _DEFAULT_ENTITY_COLOR.
_ENTITY_COLORS = {
    'Severity': '#a3e635',
    'Sign_symptom': '#1e3a8a',
    'Medication': '#c0c0c0',
    'Age': '#a3e635',
    'Sex': '#a3e635',
    'Diagnostic_procedure': '#c0c0c0',
    'Biological_structure': '#c0c0c0',
}
_DEFAULT_ENTITY_COLOR = '#c0c0c0'


def _reaction_scores(text_input):
    """Return the classifier's softmax probabilities for one text as a 1-D numpy array."""
    # truncation=True stops texts longer than the model's maximum input length
    # from raising in the forward pass.
    encoded_input = tokenizer(text_input, return_tensors='pt', truncation=True).to(device)
    # Inference only: no_grad skips building an autograd graph for each request.
    with torch.no_grad():
        output = model(**encoded_input)
    return torch.softmax(output.logits, dim=-1)[0].cpu().numpy()


def _shap_explanation_html(text_input):
    """Return SHAP explanation HTML for *text_input*, or a placeholder on failure.

    This step is best effort: any failure is printed and must not block the
    prediction itself.
    """
    try:
        shap_values = explainer([text_input])
        return generate_shap_plot(shap_values)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        return "<p>SHAP explanation not available.</p>"


def _entity_highlight_html(text_input):
    """Return *text_input* as HTML with each NER entity wrapped in a <mark> tag.

    The user's text is HTML-escaped, so characters such as '<' and '&' are
    shown literally instead of being interpreted as markup. This step is
    best effort: on any failure a placeholder paragraph is returned.
    """
    import html

    try:
        entities = sorted(ner_pipe(text_input), key=lambda ent: ent['start'])
        parts = ["<div style='line-height: 1.5; font-family: Poppins;'>"]
        prev_end = 0
        for entity in entities:
            start, end = entity['start'], entity['end']
            entity_type = entity['entity_group']
            color = _ENTITY_COLORS.get(entity_type, _DEFAULT_ENTITY_COLOR)
            # Plain text between the previous entity and this one.
            parts.append(html.escape(text_input[prev_end:start]))
            parts.append(
                f"<mark style='background-color:{color}; border-radius: 4px; font-weight: bold;'>"
                f"{html.escape(text_input[start:end])} ({html.escape(entity_type)})</mark>"
            )
            prev_end = end
        parts.append(html.escape(text_input[prev_end:]))
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:
        print(f"NER processing failed: {e}")
        return "<p>NER processing not available.</p>"


# ADR Prediction Function
def adr_predict(x):
    """Classify a free-text symptom description and build its explanation views.

    Args:
        x: the user's text. It is converted with str() and lower-cased before use.

    Returns:
        A tuple (label_output, shap_html, ner_html):
        - label_output maps "Severe Reaction" to the probability of model
          output 1 and "Non-severe Reaction" to output 0 (for gr.Label).
        - shap_html is the SHAP explanation, or a placeholder.
        - ner_html is the entity-highlighted text, or a placeholder.
    """
    text_input = str(x).lower()
    scores = _reaction_scores(text_input)
    local_plot = _shap_explanation_html(text_input)
    htext = _entity_highlight_html(text_input)
    label_output = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return label_output, local_plot, htext
def main(prob1):
    """Handle the "Get Your Insight" click by delegating to adr_predict.

    Args:
        prob1: text from the symptom textbox.

    Returns:
        Whatever adr_predict returns, unchanged: the (label scores, SHAP HTML,
        NER HTML) triple wired to the three output components.
    """
    result = adr_predict(prob1)
    return result
# CSS Styling with Poppins Font
# Page stylesheet, injected as raw HTML through gr.HTML(css_content) in the UI.
# It covers:
# - loading the Poppins font from Google Fonts;
# - the dark-blue page theme;
# - the #vertical_line divider (an elem_id on a gr.Markdown);
# - buttons and the .gr-header banner;
# - tab styling;
# - the .dashboard-header / .dashboard-text classes used in fake_dashboard().
# NOTE(review): .gr-button and .tab-item are not assigned anywhere in this
# file. They only take effect if the installed Gradio version renders
# elements with those class names - confirm.
css_content = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;600&display=swap');
body, .gradio-container {
font-family: 'Poppins', sans-serif;
background-color: #1e3a8a;
color: #ffffff;
}
#vertical_line {
height: 300px;
width: 2px;
background-color: #a3e635;
margin: 0 20px;
display: inline-block;
}
.gr-button {
background-color: #a3e635;
color: #1e3a8a;
font-weight: bold;
padding: 10px;
border-radius: 5px;
width: 100%;
}
.gr-button:hover {
background-color: #c0c0c0;
}
.gr-header {
background: linear-gradient(90deg, #1e3a8a, #64b5f6);
color: #ffffff;
padding: 20px;
border-radius: 8px;
margin-bottom: 10px;
text-align: center;
font-size: 26px;
}
.tab-item {
background-color: #1e3a8a;
border-radius: 5px;
margin: 2px;
padding: 8px;
font-size: 20px;
text-align: center;
color: #ffffff;
}
.tab-item:hover {
background-color: #64b5f6;
}
.dashboard-text {
font-size: 18px;
color: #ffffff;
}
.dashboard-header {
font-size: 22px;
color: #a3e635;
font-weight: bold;
}
</style>
"""
# Fake Dashboard Data
def fake_dashboard():
    """Return static placeholder HTML for the wearable-data dashboard card.

    Every value (heart rate, temperature, HRV, next dose, last symptom) is
    hard-coded demo data. Nothing is read from a device.

    The emoji and the degree sign are written as ``\\N{...}`` escapes so the
    exact code points (including the U+FE0F emoji-presentation selector)
    survive any re-encoding of this file.
    """
    dashboard_html = """
<div style='padding: 20px; background-color: #1e3a8a; border-radius: 10px; margin: 10px;'>
<h3 class='dashboard-header'>\N{THERMOMETER}\N{VARIATION SELECTOR-16} Wearable Data</h3>
<p class='dashboard-text'>\N{HEAVY BLACK HEART}\N{VARIATION SELECTOR-16} Heart Rate: 72 bpm</p>
<p class='dashboard-text'>\N{THERMOMETER}\N{VARIATION SELECTOR-16} Temperature: 98.6\N{DEGREE SIGN}F</p>
<p class='dashboard-text'>\N{HIGH VOLTAGE SIGN} HRV: 55 ms</p>
<hr>
<p class='dashboard-text'>\N{PILL} Next dose: Ibuprofen 400 mg - 8:00 AM</p>
<hr>
<p class='dashboard-text'>\N{CALENDAR} Last Symptom: Nausea - Yesterday at 6:30 PM</p>
</div>
"""
    return dashboard_html
# Additional function to generate alerts
def generate_alerts():
    """Return static HTML for the Alerts & Notifications tab.

    The alert entries (missed dose, dose reminder, streak, symptom check-in)
    are hard-coded demo content.

    Emoji and the typographic apostrophe are written as ``\\N{...}`` escapes so
    the exact code points survive any re-encoding of this file.
    """
    alerts_html = """
<div style='padding: 20px; background-color: #1e3a8a; border-radius: 10px; margin: 10px; color: #ffffff;'>
<h3 style='color: #a3e635;'>\N{BELL} Alerts & Notifications</h3>
<ul style='font-size: 18px; list-style-type: none; padding-left: 0;'>
<li>\N{POLICE CARS REVOLVING LIGHT} <strong>5/6 10:00 AM:</strong> Missed your meds. How are you feeling?</li>
<li>\N{PILL} <strong>5/6 8:00 AM:</strong> Have you taken your medication today?</li>
<li>\N{GLOWING STAR} <strong>5/5 9:00 AM:</strong> 7-day streak! You\N{RIGHT SINGLE QUOTATION MARK}ve been consistent with your meds. Keep it up!</li>
<li>\N{THINKING FACE} <strong>5/4 11:00 AM:</strong> Still feeling nauseous? Let us know how you\N{RIGHT SINGLE QUOTATION MARK}re doing.</li>
</ul>
</div>
"""
    return alerts_html
# Gradio Interface with the Patient Diary, Dashboard and Alerts tabs.
# The UI is built exactly once; `demo` is the module-level Blocks object that
# gets launched (and that Gradio tooling looks up by name).
with gr.Blocks(title="AwareRx. Painless Input. Painless Life.") as demo:
    gr.HTML(css_content)
    # NOTE(review): the trailing emoji in this header looks mis-encoded
    # (UTF-8 read as Windows-874). The original character cannot be
    # recovered reliably from here - confirm and replace it.
    gr.Markdown("<div class='gr-header'>AwareRx. Painless Input. Painless Life. ๐Ÿช</div>")
    gr.Markdown("### How are you feeling, king?")
    gr.Markdown("#### This is NOT for medical diagnosis.")
    gr.Markdown("---")
    with gr.Tabs():
        # Patient Diary Tab: free text in; severity, SHAP and NER views out.
        with gr.TabItem("Patient Diary"):
            with gr.Row():
                with gr.Column(scale=2):
                    prob1 = gr.Textbox(label="Describe your symptoms here:", lines=3, placeholder="E.g., I think I went too hard at happy hour. I feel nauseous.")
                    submit_btn = gr.Button("Get Your Insight")
                    label = gr.Label(label="Predicted Reaction Severity")
                    local_plot = gr.HTML(label="Insight Explanation")
                    htext = gr.HTML(label="Named Entities Identified")
                    # main returns one value per output component, in this order.
                    submit_btn.click(fn=main, inputs=[prob1], outputs=[label, local_plot, htext])
                # Divider between the diary column and the dashboard column.
                gr.Markdown("### |", elem_id="vertical_line")
                with gr.Column(scale=1):
                    gr.HTML(fake_dashboard())
        # Dashboard Tab
        with gr.TabItem("Dashboard"):
            gr.HTML(fake_dashboard())
        # Alerts Tab
        with gr.TabItem("Alerts"):
            gr.HTML(generate_alerts())
    # NOTE(review): nothing in this file stores user data. Confirm this
    # privacy/HIPAA statement is accurate before shipping.
    gr.Markdown("Data privacy is our priority. All information is securely stored following HIPAA guidelines.")

# Start the server only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()