|
|
|
|
|
""" |
|
|
NurseSim-Triage Hybrid Agent Entry Point |
|
|
|
|
|
This module combines the A2A API (for AgentBeats) and the Gradio UI (for Human/Demo) |
|
|
into a single FastAPI application listening on port 7860. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import torch |
|
|
import uvicorn |
|
|
import asyncio |
|
|
import gradio as gr |
|
|
from contextlib import asynccontextmanager |
|
|
from fastapi import FastAPI, HTTPException, Request |
|
|
from fastapi.responses import JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from typing import Dict, Any |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Vitals(BaseModel): |
|
|
heart_rate: int = 80 |
|
|
blood_pressure: str = "120/80" |
|
|
spo2: int = 98 |
|
|
temperature: float = 37.0 |
|
|
|
|
|
class TaskInput(BaseModel): |
|
|
complaint: str |
|
|
vitals: Vitals |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NurseSimTriageAgent: |
|
|
""" |
|
|
Shared agent logic for both API and UI. |
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
"""Initialize the triage agent placeholder.""" |
|
|
self.model = None |
|
|
self.tokenizer = None |
|
|
self.HF_TOKEN = os.environ.get("HF_TOKEN") |
|
|
|
|
|
if not self.HF_TOKEN: |
|
|
print("WARNING: HF_TOKEN not set. Model loading will fail if authentication is required.") |
|
|
|
|
|
async def load_model(self): |
|
|
"""Load the base model and LoRA adapters asynchronously.""" |
|
|
if self.model is not None: |
|
|
return |
|
|
|
|
|
try: |
|
|
print("β³ Starting model load...") |
|
|
base_model_id = "meta-llama/Llama-3.2-3B-Instruct" |
|
|
adapter_id = "NurseCitizenDeveloper/NurseSim-Triage-Llama-3.2-3B" |
|
|
|
|
|
|
|
|
await asyncio.to_thread(self._load_weights, base_model_id, adapter_id) |
|
|
|
|
|
print("β
Model loaded successfully!") |
|
|
except Exception as e: |
|
|
print(f"β CRITICAL ERROR loading model: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
|
|
|
def _load_weights(self, base_model_id, adapter_id): |
|
|
print(f"Loading tokenizer from {adapter_id}...") |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=self.HF_TOKEN) |
|
|
|
|
|
print(f"Loading base model {base_model_id} with 4-bit quantization...") |
|
|
bnb_config = BitsAndBytesConfig( |
|
|
load_in_4bit=True, |
|
|
bnb_4bit_compute_dtype=torch.float16, |
|
|
bnb_4bit_quant_type="nf4", |
|
|
bnb_4bit_use_double_quant=True, |
|
|
) |
|
|
|
|
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
|
base_model_id, |
|
|
quantization_config=bnb_config, |
|
|
device_map="auto", |
|
|
low_cpu_mem_usage=True, |
|
|
token=self.HF_TOKEN, |
|
|
) |
|
|
|
|
|
print(f"Applying LoRA adapters from {adapter_id}...") |
|
|
self.model = PeftModel.from_pretrained(self.model, adapter_id, token=self.HF_TOKEN) |
|
|
self.model.eval() |
|
|
|
|
|
def get_response(self, complaint: str, hr: int, bp: str, spo2: int, temp: float) -> str: |
|
|
"""Shared inference logic.""" |
|
|
if self.model is None: |
|
|
return "β οΈ System is warming up. Please try again in 30 seconds." |
|
|
|
|
|
prompt = f"""### Instruction: |
|
|
You are an expert A&E Triage Nurse. Assess the following patient and provide your triage decision. |
|
|
|
|
|
### Input: |
|
|
Patient Complaint: {complaint} |
|
|
Vitals: HR {hr}, BP {bp}, SpO2 {spo2}%, Temp {temp}C. |
|
|
|
|
|
### Response:""" |
|
|
|
|
|
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = self.model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=256, |
|
|
do_sample=True, |
|
|
temperature=0.7, |
|
|
pad_token_id=self.tokenizer.eos_token_id, |
|
|
) |
|
|
|
|
|
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
if "### Response:" in response: |
|
|
response = response.split("### Response:")[-1].strip() |
|
|
|
|
|
return response |
|
|
|
|
|
def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]: |
|
|
"""Process an API task.""" |
|
|
if self.model is None: |
|
|
return { |
|
|
"error": "ModelStillLoading", |
|
|
"message": "The agent is still warming up. Please retry in 30 seconds." |
|
|
} |
|
|
|
|
|
try: |
|
|
complaint = task.get("complaint", "") |
|
|
vitals = task.get("vitals", {}) |
|
|
response = self.get_response( |
|
|
complaint, |
|
|
vitals.get("heart_rate", 80), |
|
|
vitals.get("blood_pressure", "120/80"), |
|
|
vitals.get("spo2", 98), |
|
|
vitals.get("temperature", 37.0) |
|
|
) |
|
|
|
|
|
return { |
|
|
"triage_category": self._extract_triage_category(response), |
|
|
"assessment": response, |
|
|
"recommended_action": self._extract_recommended_action(response) |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
return {"error": str(e), "triage_category": "Error"} |
|
|
|
|
|
def _extract_triage_category(self, response: str) -> str: |
|
|
response_lower = response.lower() |
|
|
if "immediate" in response_lower or "resuscitation" in response_lower: return "Immediate" |
|
|
elif "very urgent" in response_lower or "emergency" in response_lower: return "Very Urgent" |
|
|
elif "urgent" in response_lower: return "Urgent" |
|
|
elif "standard" in response_lower: return "Standard" |
|
|
elif "non-urgent" in response_lower or "non urgent" in response_lower: return "Non-Urgent" |
|
|
else: return "Standard" |
|
|
|
|
|
def _extract_recommended_action(self, response: str) -> str: |
|
|
if "monitor" in response.lower(): return "Monitor patient closely" |
|
|
elif "immediate" in response.lower() or "urgent" in response.lower(): return "Immediate medical attention required" |
|
|
else: return "Continue assessment and treatment as per protocol" |
|
|
|
|
|
def health_check(self) -> Dict[str, Any]: |
|
|
return { |
|
|
"status": "healthy" if self.model is not None else "loading", |
|
|
"model_loaded": self.model is not None, |
|
|
"gpu_available": torch.cuda.is_available() |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent = NurseSimTriageAgent() |
|
|
|
|
|
@asynccontextmanager |
|
|
async def lifespan(app: FastAPI): |
|
|
print("π Server starting. Triggering model load task...") |
|
|
asyncio.create_task(agent.load_model()) |
|
|
yield |
|
|
print("π Server shutting down.") |
|
|
|
|
|
app = FastAPI(title="NurseSim-Triage Agent", version="1.2.0", lifespan=lifespan) |
|
|
|
|
|
app.add_middleware( |
|
|
CORSMiddleware, |
|
|
allow_origins=["*"], |
|
|
allow_credentials=True, |
|
|
allow_methods=["*"], |
|
|
allow_headers=["*"], |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
async def health_check(): |
|
|
return agent.health_check() |
|
|
|
|
|
@app.get("/.well-known/agent-card.json") |
|
|
async def get_agent_card(): |
|
|
card_path = ".well-known/agent-card.json" |
|
|
if os.path.exists(card_path): |
|
|
with open(card_path, "r") as f: |
|
|
return json.load(f) |
|
|
raise HTTPException(status_code=404, detail="Agent card not found") |
|
|
|
|
|
@app.post("/process-task") |
|
|
async def process_task(task: TaskInput): |
|
|
result = agent.process_task(task.dict()) |
|
|
if "error" in result and result.get("message") == "ModelStillLoading": |
|
|
raise HTTPException(status_code=503, detail=result["message"]) |
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_predict(complaint, hr, bp, spo2, temp): |
|
|
return agent.get_response(complaint, hr, bp, spo2, temp) |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo: |
|
|
gr.Markdown(""" |
|
|
# π©Ί NurseSim AI: Emergency Triage Simulator |
|
|
**An AI agent fine-tuned for the Manchester Triage System (MTS).** |
|
|
*Developed for the OpenEnv Challenge by NurseCitizenDeveloper.* |
|
|
|
|
|
> β‘ **Hybrid Mode**: Serving both Gradio UI and A2A API (AgentBeats) |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
complaint = gr.Textbox(label="Chief Complaint", placeholder="e.g., Shortness of breath...") |
|
|
with gr.Row(): |
|
|
hr = gr.Number(label="Heart Rate", value=80) |
|
|
bp = gr.Textbox(label="Blood Pressure", placeholder="e.g., 120/80") |
|
|
with gr.Row(): |
|
|
spo2 = gr.Slider(label="SpO2 (%)", minimum=50, maximum=100, value=98) |
|
|
temp = gr.Number(label="Temperature (C)", value=37.0) |
|
|
|
|
|
submit_btn = gr.Button("Assess Patient", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
output_text = gr.Textbox(label="AI Triage Assessment", lines=10) |
|
|
gr.Markdown("### β οΈ Research Prototype - Not for Clinical Use") |
|
|
|
|
|
submit_btn.click(gradio_predict, inputs=[complaint, hr, bp, spo2, temp], outputs=output_text) |
|
|
|
|
|
gr.Examples( |
|
|
examples=[ |
|
|
["Crushing chest pain and nausea", 110, "90/60", 94, 37.2], |
|
|
["Twisted ankle at football", 75, "125/85", 99, 36.8], |
|
|
], |
|
|
inputs=[complaint, hr, bp, spo2, temp] |
|
|
) |
|
|
|
|
|
|
|
|
app = gr.mount_gradio_app(app, demo, path="/") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
print("Starting Hybrid Server on port 7860...") |
|
|
uvicorn.run(app, host="0.0.0.0", port=7860) |
|
|
|