# Hugging Face Spaces status header captured with this file: "Build error".
import base64
import csv
import html
import io
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import shap
import tensorflow as tf
import torch
import transformers
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
# Raise the CSV field-size limit as far as the platform allows. sys.maxsize
# does not fit in the C long used by the csv module on platforms where a
# long is 32 bits (e.g. Windows), which raises OverflowError; fall back to
# the largest 32-bit signed value there.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

# Run models on the first CUDA GPU when one is available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# ADR severity classifier and its tokenizer, loaded from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("TyHamil/ADRv2025")
model = AutoModelForSequenceClassification.from_pretrained("TyHamil/ADRv2025").to(device)

# Text-classification pipeline over the same model; top_k=None returns a
# score for every label.
# NOTE(review): not referenced by the visible code (the SHAP explainer uses
# predict_prob instead) — confirm before relying on it or removing it.
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
| # SHAP explainer | |
| #explainer = shap.Explainer(pred) | |
| import shap | |
def predict_prob(texts):
    """Return class probabilities for a batch of texts.

    This is the model function handed to the SHAP explainer. SHAP may call it
    with a NumPy array of (masked) strings, which Hugging Face tokenizers
    reject (they accept only ``str`` or lists of ``str``), so the input is
    normalised to a plain ``list[str]`` first. A single ``str`` is treated as
    a batch of one.

    Args:
        texts: a string, or a list / tuple / NumPy array of strings.

    Returns:
        NumPy array with one row per input text and one column per model
        label, holding softmax probabilities (each row sums to 1).
    """
    if isinstance(texts, str):
        texts = [texts]
    batch = [str(text) for text in texts]
    # truncation keeps inputs within the model's maximum sequence length;
    # padding makes a rectangular batch of variable-length texts.
    encoded = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = model(**encoded).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs.cpu().numpy()
# SHAP explainer over predict_prob. Passing the tokenizer as the masker lets
# SHAP hide tokens of the input text. The output names follow the index order
# used by adr_predict: column 0 = non-severe, column 1 = severe.
explainer = shap.Explainer(
    predict_prob,
    tokenizer,
    output_names=["Non-severe Reaction", "Severe Reaction"],
)

# Biomedical named-entity recognition. aggregation_strategy="simple" merges
# sub-word predictions into whole entity spans with character offsets.
# It runs on the same device as the classification pipeline.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=device,
)
# SHAP Plotting Function
def generate_shap_plot(shap_values, output_index=1):
    """Render the SHAP explanation of the first explained text as HTML.

    ``shap.plots.text`` is an HTML visualisation rather than a matplotlib
    plot. It takes no ``show`` argument and draws nothing on a matplotlib
    figure. With ``display=False`` it returns the HTML markup instead of
    displaying it in a notebook, and that markup can go straight into a
    ``gr.HTML`` component.

    Args:
        shap_values: ``shap.Explanation`` for a batch of texts, as returned
            by ``explainer([...])``. Only the first text is rendered.
        output_index: model output to explain when the explanation has one
            column per output. The default is 1, the column that adr_predict
            reports as "Severe Reaction".

    Returns:
        HTML string with each token coloured by its contribution.
    """
    explanation = shap_values[0]
    # A single text explained for a multi-output model is (tokens, outputs).
    # The text plot expects one value per token, so keep a single output.
    if len(explanation.shape) > 1:
        explanation = explanation[:, output_index]
    return shap.plots.text(explanation, display=False)
# ADR Prediction Function

# Highlight colour for each NER entity group. Groups not listed fall back to
# _DEFAULT_ENTITY_COLOR.
_ENTITY_COLORS = {
    "Severity": "#a3e635", "Sign_symptom": "#1e3a8a", "Medication": "#c0c0c0",
    "Age": "#a3e635", "Sex": "#a3e635", "Diagnostic_procedure": "#c0c0c0",
    "Biological_structure": "#c0c0c0",
}
_DEFAULT_ENTITY_COLOR = "#c0c0c0"


def _explain(text_input):
    """Return SHAP explanation HTML for *text_input*, or a placeholder on failure."""
    try:
        return generate_shap_plot(explainer([text_input]))
    except Exception as e:  # best effort: the prediction itself is still returned
        print(f"SHAP explanation failed: {e}")
        return "<p>SHAP explanation not available.</p>"


def _highlight_entities(text_input):
    """Return *text_input* as HTML with NER entity spans highlighted.

    The text comes straight from the user and is rendered by a ``gr.HTML``
    component, so every piece of it is HTML-escaped. This is best effort:
    on any failure a short placeholder message is returned instead.
    """
    try:
        entities = sorted(ner_pipe(text_input), key=lambda ent: ent["start"])
        parts = ["<div style='line-height: 1.5; font-family: Poppins;'>"]
        prev_end = 0
        for entity in entities:
            start, end = entity["start"], entity["end"]
            if start < prev_end:
                # A span overlapping one already rendered would emit text twice.
                continue
            entity_type = entity["entity_group"]
            color = _ENTITY_COLORS.get(entity_type, _DEFAULT_ENTITY_COLOR)
            parts.append(html.escape(text_input[prev_end:start]))
            parts.append(
                f"<mark style='background-color:{color}; border-radius: 4px; font-weight: bold;'>"
                f"{html.escape(text_input[start:end])} ({html.escape(entity_type)})</mark>"
            )
            prev_end = end
        parts.append(html.escape(text_input[prev_end:]))
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:  # best effort: the prediction itself is still returned
        print(f"NER processing failed: {e}")
        return "<p>NER processing not available.</p>"


def adr_predict(x):
    """Classify ADR severity for free text and build the three UI outputs.

    Args:
        x: user-entered text. It is converted with ``str`` and lower-cased
            before use.

    Returns:
        A tuple ``(label_output, local_plot, htext)``:
        - label_output: dict mapping "Severe Reaction" and
          "Non-severe Reaction" to probabilities, for ``gr.Label``.
        - local_plot: SHAP explanation HTML.
        - htext: NER-highlighted HTML.
        Each HTML part becomes a short message if it fails.
    """
    text_input = str(x).lower()
    # predict_prob truncates to the model's maximum length and runs without
    # autograd. Row 0 holds this single input's probability vector.
    scores = predict_prob([text_input])[0]
    local_plot = _explain(text_input)
    htext = _highlight_entities(text_input)
    label_output = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return label_output, local_plot, htext
def main(prob1):
    """Gradio click handler: forward the diary text to ``adr_predict``.

    Returns the ``(label, explanation HTML, entity HTML)`` tuple unchanged.
    """
    outputs = adr_predict(prob1)
    return outputs
# CSS Styling with Poppins Font
# Global stylesheet for the app. It is wrapped in <style> tags because it is
# injected into the page as raw HTML through gr.HTML(css_content), not passed
# to gr.Blocks(css=...). #vertical_line matches the elem_id of the divider
# Markdown, .gr-header the header <div>, and .dashboard-header /
# .dashboard-text the classes used in the dashboard HTML.
# NOTE(review): no visible component sets .tab-item, and it is unconfirmed
# whether the installed Gradio version emits .gr-button — verify these rules
# still apply.
css_content = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;600&display=swap');
body, .gradio-container {
font-family: 'Poppins', sans-serif;
background-color: #1e3a8a;
color: #ffffff;
}
#vertical_line {
height: 300px;
width: 2px;
background-color: #a3e635;
margin: 0 20px;
display: inline-block;
}
.gr-button {
background-color: #a3e635;
color: #1e3a8a;
font-weight: bold;
padding: 10px;
border-radius: 5px;
width: 100%;
}
.gr-button:hover {
background-color: #c0c0c0;
}
.gr-header {
background: linear-gradient(90deg, #1e3a8a, #64b5f6);
color: #ffffff;
padding: 20px;
border-radius: 8px;
margin-bottom: 10px;
text-align: center;
font-size: 26px;
}
.tab-item {
background-color: #1e3a8a;
border-radius: 5px;
margin: 2px;
padding: 8px;
font-size: 20px;
text-align: center;
color: #ffffff;
}
.tab-item:hover {
background-color: #64b5f6;
}
.dashboard-text {
font-size: 18px;
color: #ffffff;
}
.dashboard-header {
font-size: 22px;
color: #a3e635;
font-weight: bold;
}
</style>
"""
# Fake Dashboard Data
def fake_dashboard():
    """Return static placeholder HTML for the wearable-data dashboard card.

    Every value is hard-coded demo data; nothing is read from a device.

    The emoji and the degree sign were restored from text whose UTF-8 bytes
    had been mis-decoded as Thai (TIS-620 / cp874). For example, "❤️" had
    become "โค๏ธ" and "°F" had become "ยฐF". For the header, "Next dose" and
    "Last Symptom" icons, several emoji fit the surviving bytes, so a
    plausible one was chosen — confirm against the original design.

    Returns:
        HTML string styled by the .dashboard-header / .dashboard-text rules.
    """
    dashboard_html = """
<div style='padding: 20px; background-color: #1e3a8a; border-radius: 10px; margin: 10px;'>
<h3 class='dashboard-header'>🌡️ Wearable Data</h3>
<p class='dashboard-text'>❤️ Heart Rate: 72 bpm</p>
<p class='dashboard-text'>🌡️ Temperature: 98.6°F</p>
<p class='dashboard-text'>⚡ HRV: 55 ms</p>
<hr>
<p class='dashboard-text'>💊 Next dose: Ibuprofen 400 mg - 8:00 AM</p>
<hr>
<p class='dashboard-text'>📝 Last Symptom: Nausea - Yesterday at 6:30 PM</p>
</div>
"""
    return dashboard_html
# Additional function to generate alerts
def generate_alerts():
    """Return static placeholder HTML for the Alerts & Notifications card.

    The alert entries are hard-coded demo content.

    The curly apostrophes and emoji were restored from text whose UTF-8 bytes
    had been mis-decoded as Thai (TIS-620 / cp874). For example, "You’ve" had
    become "Youโve". Where several emoji fit the surviving bytes, a plausible
    one was chosen — confirm against the original design.

    Returns:
        HTML string containing a heading and an unstyled list of alerts.
    """
    alerts_html = """
<div style='padding: 20px; background-color: #1e3a8a; border-radius: 10px; margin: 10px; color: #ffffff;'>
<h3 style='color: #a3e635;'>🔔 Alerts & Notifications</h3>
<ul style='font-size: 18px; list-style-type: none; padding-left: 0;'>
<li>🚨 <strong>5/6 10:00 AM:</strong> Missed your meds. How are you feeling?</li>
<li>💊 <strong>5/6 8:00 AM:</strong> Have you taken your medication today?</li>
<li>🎉 <strong>5/5 9:00 AM:</strong> 7-day streak! You’ve been consistent with your meds. Keep it up!</li>
<li>🤒 <strong>5/4 11:00 AM:</strong> Still feeling nauseous? Let us know how you’re doing.</li>
</ul>
</div>
"""
    return alerts_html
# Gradio interface: Patient Diary (prediction), Dashboard and Alerts tabs.
# The UI is built once; rebuilding gr.Blocks under the same name only
# discards the earlier instances.
with gr.Blocks(title="AwareRx. Painless Input. Painless Life.") as demo:
    # css_content carries its own <style> tags, so it is injected as raw HTML.
    gr.HTML(css_content)
    gr.Markdown("<div class='gr-header'>AwareRx. Painless Input. Painless Life. 💪</div>")
    gr.Markdown("### How are you feeling, king?")
    gr.Markdown("#### This is NOT for medical diagnosis.")
    gr.Markdown("---")
    with gr.Tabs():
        # Patient Diary tab: free-text symptoms -> severity label, SHAP
        # explanation and highlighted entities (all produced by main).
        with gr.TabItem("Patient Diary"):
            with gr.Row():
                with gr.Column(scale=2):
                    prob1 = gr.Textbox(
                        label="Describe your symptoms here:",
                        lines=3,
                        placeholder="E.g., I think I went too hard at happy hour. I feel nauseous.",
                    )
                    submit_btn = gr.Button("Get Your Insight")
                    label = gr.Label(label="Predicted Reaction Severity")
                    local_plot = gr.HTML(label="Insight Explanation")
                    htext = gr.HTML(label="Named Entities Identified")
                    submit_btn.click(fn=main, inputs=[prob1], outputs=[label, local_plot, htext])
                # Divider between the two columns, styled by #vertical_line.
                # NOTE(review): original nesting was lost in the source; placed
                # at row level between the columns — confirm.
                gr.Markdown("### |", elem_id="vertical_line")
                with gr.Column(scale=1):
                    gr.HTML(fake_dashboard())
        # Dashboard tab
        with gr.TabItem("Dashboard"):
            gr.HTML(fake_dashboard())
        # Alerts tab
        with gr.TabItem("Alerts"):
            gr.HTML(generate_alerts())
    # NOTE(review): no storage code is visible in this file, so this HIPAA
    # claim is not backed by anything shown here — confirm before shipping.
    gr.Markdown("Data privacy is our priority. All information is securely stored following HIPAA guidelines.")

demo.launch()