import os

# Third-party imports keep their original relative order: `spaces` is imported
# before the torch-based libraries. NOTE(review): ZeroGPU setups are presumably
# sensitive to this order -- confirm before re-sorting alphabetically.
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# Hugging Face access token taken from the environment; None when the
# HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Process-wide cache for the loaded model and tokenizer. Both start empty and
# are filled in by load_model() the first time it runs.
model = None
tokenizer = None
def load_model():
    """Load the quantized base model with its PEFT adapter once, then reuse it.

    The first call downloads/loads the tokenizer, the 4-bit base model and the
    adapter, and stores the result in the module-level ``model`` and
    ``tokenizer`` globals. Later calls return the cached objects.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Whatever the underlying ``from_pretrained`` calls raise. On failure the
        globals are left untouched, so the next call retries from scratch.
    """
    global model, tokenizer

    if model is None:
        base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
        adapter_id = "NurseCitizenDeveloper/NurseSim-Triage-Llama-3.2-3B"

        # Build everything in locals and publish to the globals only after the
        # adapter is attached. Assigning the bare base model to the global
        # first would make a failed adapter load look like a finished load on
        # the next call, silently serving the un-adapted base model.
        loaded_tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=HF_TOKEN)

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            # Explicit config object instead of the deprecated bare
            # `load_in_4bit=True` keyword; same default 4-bit settings.
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            token=HF_TOKEN,
        )

        adapted_model = PeftModel.from_pretrained(base_model, adapter_id, token=HF_TOKEN)
        adapted_model.eval()

        tokenizer = loaded_tokenizer
        model = adapted_model

    return model, tokenizer
def _normalise_age(age):
    """Return *age* as an int, or "Unknown" when it is missing or not numeric."""
    # Compare against None/"" explicitly: a plain truthiness test would turn a
    # legitimate age of 0 (an infant) into "Unknown".
    if age is None or age == "":
        return "Unknown"
    try:
        return int(age)
    except (TypeError, ValueError, OverflowError):
        return "Unknown"


def format_prompt(complaint, hr, bp, spo2, temp, rr, avpu, age, gender, pmh):
    """Build the instruction-style prompt for one triage request.

    Args:
        complaint: Chief complaint text, quoted verbatim in the prompt.
        hr: Heart rate, rendered as ``{hr} bpm``.
        bp: Blood pressure, rendered as ``{bp} mmHg``.
        spo2: Oxygen saturation, rendered as ``{spo2}%``.
        temp: Temperature, rendered as ``{temp}C``.
        rr: Respiratory rate, rendered as ``{rr} /min``.
        avpu: AVPU value, rendered as-is.
        age: Patient age; truncated to an int. Missing (``None``/empty) or
            non-numeric values become ``"Unknown"``; ``0`` is kept as ``0``.
        gender: Patient gender, rendered as-is.
        pmh: Past medical history; an empty/falsy value becomes ``"None"``.

    Returns:
        str: The full prompt, ending with ``"### Response:\\n"`` so the model's
        completion starts on a fresh line.
    """
    # Rendered into the prompt via the dict's repr (single-quoted keys).
    history_dict = {
        'age': _normalise_age(age),
        'gender': gender,
        'relevant_PMH': pmh if pmh else "None",
        'time_course': "See complaint",
    }

    # The string bodies are deliberately flush-left: any indentation here
    # would become part of the prompt text.
    input_text = f"""PATIENT PRESENTING TO A&E TRIAGE

Chief Complaint: "{complaint}"

Vitals:
- HR: {hr} bpm
- BP: {bp} mmHg
- SpO2: {spo2}%
- RR: {rr} /min
- Temp: {temp}C
- AVPU: {avpu}

History: {history_dict}

WAITING ROOM: 12 patients | AVAILABLE BEDS: 4

What is your triage decision?"""

    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
You are an expert A&E Triage Nurse using the Manchester Triage System. Assess the following patient and provide your triage decision with clinical reasoning.

### Input:
{input_text}

### Response:
"""
@spaces.GPU(duration=120)
def triage_patient(complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu):
    """Generate a triage assessment for one patient.

    Parameter order matches the Gradio ``inputs`` list wired to this function.

    Args:
        complaint: Chief complaint text.
        age: Patient age.
        gender: Patient gender.
        pmh: Past medical history text.
        hr: Heart rate.
        bp: Blood pressure string.
        spo2: Oxygen saturation.
        rr: Respiratory rate.
        temp: Temperature.
        avpu: AVPU value.

    Returns:
        str: The model's generated text with surrounding whitespace stripped.
    """
    model, tokenizer = load_model()

    # format_prompt takes the same fields in a different positional order.
    prompt = format_prompt(complaint, hr, bp, spo2, temp, rr, avpu, age, gender, pmh)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,  # sampling: repeated runs can give different text
            temperature=0.6,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence is prompt tokens followed by the completion. Slice
    # off the prompt by length rather than decoding everything and splitting on
    # "### Response:", which truncates the answer if the model repeats that
    # marker in its own output.
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# Markdown shown in the UI. Built from flush-left lines so rendering does not
# depend on how indentation inside a triple-quoted string gets handled.
_HEADER_MARKDOWN = (
    "# 🏥 NurseSim AI: Emergency Triage Simulator\n"
    "**An AI agent fine-tuned for the Manchester Triage System (MTS).**\n"
    "\n"
    "> **Note:** This model is trained to be **Age-Aware**. A 72-year-old with chest pain is treated differently than a 20-year-old.\n"
)

_DISCLAIMER_MARKDOWN = (
    "### ⚠️ Safety Disclaimer\n"
    "This system is a **research prototype** developed for the OpenEnv Challenge.\n"
    "It is **NOT** a certified medical device and should not be used for real clinical decision-making.\n"
)

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate")) as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        # Left column: who the patient is and why they came in.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Patient Demographics")
            age = gr.Number(label="Age", value=45)
            gender = gr.Radio(["Male", "Female"], label="Gender", value="Male")
            pmh = gr.Textbox(label="Medical History (PMH)", placeholder="e.g., Hypertension, Diabetes, Asthma", lines=2)

            gr.Markdown("### 2. Presentation")
            complaint = gr.Textbox(label="Chief Complaint", placeholder="e.g., Crushing chest pain radiating to jaw", lines=2)

        # Right column: vital signs in three rows of two, then the submit button.
        with gr.Column(scale=1):
            gr.Markdown("### 3. Vital Signs")
            with gr.Row():
                hr = gr.Number(label="HR (bpm)", value=80)
                rr = gr.Number(label="RR (breaths/min)", value=16)
            with gr.Row():
                bp = gr.Textbox(label="BP (mmHg)", value="120/80")
                spo2 = gr.Slider(label="SpO2 (%)", minimum=50, maximum=100, value=98)
            with gr.Row():
                temp = gr.Number(label="Temp (C)", value=37.0)
                avpu = gr.Dropdown(["A", "V", "P", "U"], label="AVPU", value="A")

            submit_btn = gr.Button("🚨 Assess Patient", variant="primary", size="lg")

    with gr.Row():
        output_text = gr.Textbox(label="AI Triage Assessment", lines=8, show_copy_button=True)

    gr.Markdown(_DISCLAIMER_MARKDOWN)

    # The order of `inputs` must match triage_patient's positional parameters.
    submit_btn.click(
        fn=triage_patient,
        inputs=[complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu],
        outputs=output_text,
    )

    # Each example row follows the same column order as `inputs` below.
    gr.Examples(
        examples=[
            ["Crushing chest pain and nausea", 72, "Male", "HTN, High Cholesterol", 110, "90/60", 94, 24, 37.2, "A"],
            ["Twisted ankle at football", 22, "Male", "None", 75, "125/85", 99, 14, 36.8, "A"],
            ["Swollen tongue after peanuts", 25, "Female", "Nut Allergy", 120, "90/60", 91, 28, 37.5, "A"],
        ],
        inputs=[complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu],
        label="Test Scenarios",
    )

if __name__ == "__main__":
    demo.launch()