NurseCitizenDeveloper committed on
Commit
6d0ec7f
·
verified ·
1 Parent(s): 047146f

remove user login requirement from agent_main.py

Browse files
Files changed (1) hide show
  1. agent_main.py +492 -492
agent_main.py CHANGED
@@ -1,492 +1,492 @@
1
- #!/usr/bin/env python3
2
- """
3
- NurseSim-Triage Hybrid Agent Entry Point
4
-
5
- This module combines the A2A API (for AgentBeats) and the Gradio UI (for Human/Demo)
6
- into a single FastAPI application listening on port 7860.
7
- """
8
-
9
- import os
10
- import json
11
- import secrets
12
- import torch
13
- import logging
14
- import uvicorn
15
- import asyncio
16
- import secrets
17
- import gradio as gr
18
- from contextlib import asynccontextmanager
19
- from fastapi import FastAPI, HTTPException, Request, Depends, Security, status
20
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
21
- from fastapi.responses import JSONResponse
22
- from fastapi.middleware.cors import CORSMiddleware
23
- from typing import Dict, Any
24
- from pydantic import BaseModel
25
- from typing import Optional
26
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
27
- from peft import PeftModel
28
-
29
- # PDS Client for NHS patient lookup
30
- from nursesim_rl.pds_client import PDSClient, PDSEnvironment, PatientDemographics, RestrictedPatientError
31
-
32
- # ==========================================
33
- # Data Models
34
- # ==========================================
35
-
36
- class Vitals(BaseModel):
37
- heart_rate: int = 80
38
- blood_pressure: str = "120/80"
39
- spo2: int = 98
40
- temperature: float = 37.0
41
-
42
- class TaskInput(BaseModel):
43
- complaint: str
44
- vitals: Vitals
45
- nhs_number: Optional[str] = None
46
- age: Optional[int] = None
47
- gender: Optional[str] = None
48
- relevant_pmh: Optional[str] = None
49
- rr: Optional[int] = 16
50
- avpu: Optional[str] = "A"
51
-
52
- # ==========================================
53
- # Agent Core Logic
54
- # ==========================================
55
-
56
- class NurseSimTriageAgent:
57
- """
58
- Shared agent logic for both API and UI.
59
- """
60
-
61
- def __init__(self):
62
- """Initialize the triage agent placeholder."""
63
- self.model = None
64
- self.tokenizer = None
65
- self.HF_TOKEN = os.environ.get("HF_TOKEN")
66
-
67
- # Initialize PDS client for NHS patient lookup (sandbox mode)
68
- self.pds_client = PDSClient(environment=PDSEnvironment.SANDBOX)
69
-
70
- if not self.HF_TOKEN:
71
- print("WARNING: HF_TOKEN not set. Model loading will fail if authentication is required.")
72
-
73
- async def load_model(self):
74
- """Load the base model and LoRA adapters asynchronously."""
75
- if self.model is not None:
76
- return
77
-
78
- try:
79
- print("⏳ Starting model load...")
80
- base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
81
- adapter_id = "NurseCitizenDeveloper/NurseSim-Triage-Llama-3.2-3B"
82
-
83
- # Offload heavy loading to thread
84
- await asyncio.to_thread(self._load_weights, base_model_id, adapter_id)
85
-
86
- print("✅ Model loaded successfully!")
87
- except Exception as e:
88
- print(f"❌ CRITICAL ERROR loading model: {e}")
89
- import traceback
90
- traceback.print_exc()
91
-
92
- def _load_weights(self, base_model_id, adapter_id):
93
- print(f"Loading tokenizer from {adapter_id}...")
94
- self.tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=self.HF_TOKEN)
95
-
96
- print(f"Loading base model {base_model_id} with 4-bit quantization...")
97
- bnb_config = BitsAndBytesConfig(
98
- load_in_4bit=True,
99
- bnb_4bit_compute_dtype=torch.float16,
100
- bnb_4bit_quant_type="nf4",
101
- bnb_4bit_use_double_quant=True,
102
- )
103
-
104
- self.model = AutoModelForCausalLM.from_pretrained(
105
- base_model_id,
106
- quantization_config=bnb_config,
107
- device_map="auto",
108
- low_cpu_mem_usage=True,
109
- token=self.HF_TOKEN,
110
- )
111
-
112
- print(f"Applying LoRA adapters from {adapter_id}...")
113
- self.model = PeftModel.from_pretrained(self.model, adapter_id, token=self.HF_TOKEN)
114
- self.model.eval()
115
-
116
- def get_response(self, complaint: str, hr: int, bp: str, spo2: int, temp: float, rr: int = 16, avpu: str = "A", age: int = 45, gender: str = "Male", pmh: str = "None") -> str:
117
- """Shared inference logic."""
118
- if self.model is None:
119
- return "⚠️ System is warming up. Please try again in 30 seconds."
120
-
121
- # Construct History Dictionary (Critical for Model Accuracy)
122
- history_dict = {
123
- 'age': int(age) if age else "Unknown",
124
- 'gender': gender,
125
- 'relevant_PMH': pmh if pmh else "None",
126
- 'time_course': "See complaint"
127
- }
128
-
129
- input_text = f"""PATIENT PRESENTING TO A&E TRIAGE
130
-
131
- Chief Complaint: "{complaint}"
132
-
133
- Vitals:
134
- - HR: {hr} bpm
135
- - BP: {bp} mmHg
136
- - SpO2: {spo2}%
137
- - RR: {rr} /min
138
- - Temp: {temp}C
139
- - AVPU: {avpu}
140
-
141
- History: {history_dict}
142
-
143
- WAITING ROOM: 12 patients | AVAILABLE BEDS: 4
144
-
145
- What is your triage decision?"""
146
-
147
- prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
148
-
149
- ### Instruction:
150
- You are an expert A&E Triage Nurse using the Manchester Triage System. Assess the following patient and provide your triage decision with clinical reasoning.
151
-
152
- ### Input:
153
- {input_text}
154
-
155
- ### Response:
156
- """
157
-
158
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
159
-
160
- with torch.no_grad():
161
- outputs = self.model.generate(
162
- **inputs,
163
- max_new_tokens=256,
164
- do_sample=True,
165
- temperature=0.6,
166
- top_p=0.9,
167
- pad_token_id=self.tokenizer.eos_token_id,
168
- )
169
-
170
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
171
- if "### Response:" in response:
172
- try:
173
- response = response.split("### Response:")[-1].strip()
174
- except Exception:
175
- pass
176
-
177
- return response
178
-
179
- def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
180
- """Process an API task, optionally fetching patient demographics from PDS."""
181
- if self.model is None:
182
- return {
183
- "error": "ModelStillLoading",
184
- "message": "The agent is still warming up. Please retry in 30 seconds."
185
- }
186
-
187
- try:
188
- complaint = task.get("complaint", "")
189
- vitals = task.get("vitals", {})
190
- nhs_number = task.get("nhs_number")
191
-
192
- # If NHS number provided, enrich with PDS data
193
- patient_info = None
194
- if nhs_number:
195
- try:
196
- patient_info = self.lookup_patient(nhs_number)
197
- except RestrictedPatientError as e:
198
- print(f"SECURITY ALERT: {e}")
199
- # Explicitly do NOT set patient_info so data is not leaked
200
- except Exception as e:
201
- print(f"PDS lookup failed: {e}")
202
-
203
- response = self.get_response(
204
- complaint,
205
- vitals.get("heart_rate", 80),
206
- vitals.get("blood_pressure", "120/80"),
207
- vitals.get("spo2", 98),
208
- vitals.get("temperature", 37.0)
209
- )
210
-
211
- result = {
212
- "triage_category": self._extract_triage_category(response),
213
- "assessment": response,
214
- "recommended_action": self._extract_recommended_action(response)
215
- }
216
-
217
- # Include patient info if retrieved
218
- if patient_info:
219
- result["patient"] = {
220
- "nhs_number": patient_info.nhs_number,
221
- "name": patient_info.full_name,
222
- "age": patient_info.age,
223
- "gender": patient_info.gender,
224
- "gp_practice": patient_info.gp_practice_name,
225
- }
226
-
227
- return result
228
-
229
- except Exception as e:
230
- logger.exception("Error processing task")
231
- return {"error": "Internal Processing Error", "triage_category": "Error"}
232
-
233
- def lookup_patient(self, nhs_number: str) -> PatientDemographics:
234
- """
235
- Look up patient demographics from NHS PDS.
236
-
237
- Args:
238
- nhs_number: 10-digit NHS number
239
-
240
- Returns:
241
- PatientDemographics object with patient details
242
- """
243
- return self.pds_client.lookup_patient_sync(nhs_number)
244
-
245
- def _extract_triage_category(self, response: str) -> str:
246
- response_lower = response.lower()
247
- if "immediate" in response_lower or "resuscitation" in response_lower: return "Immediate"
248
- elif "very urgent" in response_lower or "emergency" in response_lower: return "Very Urgent"
249
- elif "urgent" in response_lower: return "Urgent"
250
- elif "standard" in response_lower: return "Standard"
251
- elif "non-urgent" in response_lower or "non urgent" in response_lower: return "Non-Urgent"
252
- else: return "Standard"
253
-
254
- def _extract_recommended_action(self, response: str) -> str:
255
- if "monitor" in response.lower(): return "Monitor patient closely"
256
- elif "immediate" in response.lower() or "urgent" in response.lower(): return "Immediate medical attention required"
257
- else: return "Continue assessment and treatment as per protocol"
258
-
259
- def health_check(self) -> Dict[str, Any]:
260
- return {
261
- "status": "healthy" if self.model is not None else "loading",
262
- "model_loaded": self.model is not None,
263
- "gpu_available": torch.cuda.is_available()
264
- }
265
-
266
- # ==========================================
267
- # Application Setup
268
- # ==========================================
269
-
270
- # Configure logging
271
- logging.basicConfig(
272
- level=logging.INFO,
273
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
274
- )
275
- logger = logging.getLogger(__name__)
276
-
277
- agent = NurseSimTriageAgent()
278
-
279
- @asynccontextmanager
280
- async def lifespan(app: FastAPI):
281
- print("🚀 Server starting. Triggering model load task...")
282
- asyncio.create_task(agent.load_model())
283
- yield
284
- print("🛑 Server shutting down.")
285
-
286
- app = FastAPI(title="NurseSim-Triage Agent", version="1.2.0", lifespan=lifespan)
287
-
288
- app.add_middleware(
289
- CORSMiddleware,
290
- allow_origins=["*"],
291
- allow_credentials=True,
292
- allow_methods=["*"],
293
- allow_headers=["*"],
294
- )
295
-
296
- # ==========================================
297
- # Security
298
- # ==========================================
299
-
300
- security = HTTPBearer()
301
-
302
- async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
303
- """
304
- Verify API key or HF token from Authorization header.
305
- Fail-closed: If no keys are configured, all access is denied.
306
- """
307
- api_key = os.environ.get("API_KEY")
308
- hf_token = os.environ.get("HF_TOKEN")
309
-
310
- if not api_key and not hf_token:
311
- # System locked down if no keys configured
312
- raise HTTPException(
313
- status_code=status.HTTP_403_FORBIDDEN,
314
- detail="System misconfigured: No authentication keys set."
315
- )
316
-
317
- token = credentials.credentials
318
-
319
- # Check against available keys
320
- if api_key and secrets.compare_digest(token, api_key):
321
- return token
322
- if hf_token and secrets.compare_digest(token, hf_token):
323
- return token
324
-
325
- raise HTTPException(
326
- status_code=status.HTTP_401_UNAUTHORIZED,
327
- detail="Invalid authentication credentials",
328
- headers={"WWW-Authenticate": "Bearer"},
329
- )
330
-
331
- def get_gradio_auth():
332
- """
333
- Get authentication credentials for Gradio UI.
334
- Mirroring the API security: supports both API_KEY and HF_TOKEN.
335
- """
336
- auth_creds = []
337
- api_key = os.environ.get("API_KEY")
338
- hf_token = os.environ.get("HF_TOKEN")
339
-
340
- if api_key:
341
- auth_creds.append(("admin", api_key))
342
- if hf_token:
343
- auth_creds.append(("admin", hf_token))
344
-
345
- if not auth_creds:
346
- random_key = secrets.token_urlsafe(16)
347
- print(f"WARNING: No authentication keys set. Gradio UI locked with random key: {random_key}")
348
- auth_creds.append(("admin", random_key))
349
-
350
- return auth_creds
351
-
352
- # ==========================================
353
- # API Endpoints
354
- # ==========================================
355
-
356
- @app.get("/health")
357
- async def health_check():
358
- return agent.health_check()
359
-
360
- @app.get("/.well-known/agent-card.json")
361
- async def get_agent_card():
362
- card_path = ".well-known/agent-card.json"
363
- if os.path.exists(card_path):
364
- with open(card_path, "r") as f:
365
- return json.load(f)
366
- raise HTTPException(status_code=404, detail="Agent card not found")
367
-
368
- @app.post("/process-task", dependencies=[Depends(verify_api_key)])
369
- async def process_task(task: TaskInput):
370
- result = agent.process_task(task.dict())
371
- if "error" in result and result.get("message") == "ModelStillLoading":
372
- raise HTTPException(status_code=503, detail=result["message"])
373
- return result
374
-
375
- class PatientLookupRequest(BaseModel):
376
- nhs_number: str
377
-
378
- @app.post("/lookup-patient", dependencies=[Depends(verify_api_key)])
379
- async def api_lookup_patient(request: PatientLookupRequest):
380
- """Direct endpoint to lookup patient details from NHS PDS. Requires authentication."""
381
- try:
382
- patient = agent.lookup_patient(request.nhs_number)
383
- return {
384
- "nhs_number": patient.nhs_number,
385
- "full_name": patient.full_name,
386
- "date_of_birth": patient.date_of_birth,
387
- "age": patient.age,
388
- "gender": patient.gender,
389
- "address": patient.address,
390
- "gp_practice": patient.gp_practice_name
391
- }
392
- except RestrictedPatientError as e:
393
- logger.warning(f"Access denied for restricted patient: {request.nhs_number}")
394
- raise HTTPException(status_code=403, detail="🚫 ACCESS DENIED: Restricted Patient Record")
395
- except ValueError as e:
396
- raise HTTPException(status_code=400, detail=str(e))
397
- except Exception as e:
398
- logger.exception("Unexpected error during patient lookup")
399
- raise HTTPException(status_code=500, detail="Internal Server Error")
400
-
401
- # ==========================================
402
- # Gradio UI Integration
403
- # ==========================================
404
-
405
- def lookup_patient_ui(nhs_no):
406
- """Gradio handler for PDS lookup."""
407
- if not nhs_no:
408
- return 45, "Male", "", "Please enter an NHS Number."
409
- try:
410
- patient = agent.lookup_patient(nhs_no)
411
- pmh_context = f"Registered GP: {patient.gp_practice_name}"
412
- status_msg = f"✅ Verified: {patient.full_name}"
413
- return patient.age, patient.gender, pmh_context, status_msg
414
- except RestrictedPatientError:
415
- return 45, "Male", "", "🚫 ACCESS DENIED: Restricted Record"
416
- except Exception as e:
417
- return 45, "Male", "", f"❌ Lookup failed: {str(e)}"
418
-
419
- def gradio_predict(complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu):
420
- return agent.get_response(complaint, hr, bp, spo2, temp, rr, avpu, age, gender, pmh)
421
-
422
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate")) as demo:
423
- gr.Markdown("""
424
- # 🏥 NurseSim AI: Emergency Triage Simulator
425
- **An AI agent fine-tuned for the Manchester Triage System (MTS).**
426
-
427
- > ⚡ **Hybrid Mode**: Serving both Gradio UI and A2A API (AgentBeats)
428
- """)
429
-
430
- with gr.Row():
431
- with gr.Column(scale=1):
432
- gr.Markdown("### 1. Patient Demographics")
433
- with gr.Row():
434
- nhs_number = gr.Textbox(label="NHS Number", placeholder="e.g. 9000000009", scale=2)
435
- lookup_btn = gr.Button("🔍 Lookup", variant="secondary", scale=1)
436
- lookup_status = gr.Markdown("")
437
-
438
- age = gr.Number(label="Age", value=45)
439
- gender = gr.Radio(["Male", "Female"], label="Gender", value="Male")
440
- pmh = gr.Textbox(label="Medical History (PMH)", placeholder="e.g., Hypertension, Diabetes, Asthma", lines=2)
441
-
442
- gr.Markdown("### 2. Presentation")
443
- complaint = gr.Textbox(label="Chief Complaint", placeholder="e.g., Crushing chest pain radiating to jaw", lines=2)
444
-
445
- with gr.Column(scale=1):
446
- gr.Markdown("### 3. Vital Signs")
447
- with gr.Row():
448
- hr = gr.Number(label="HR (bpm)", value=80)
449
- rr = gr.Number(label="RR (breaths/min)", value=16)
450
- with gr.Row():
451
- bp = gr.Textbox(label="BP (mmHg)", value="120/80")
452
- spo2 = gr.Slider(label="SpO2 (%)", minimum=50, maximum=100, value=98)
453
- with gr.Row():
454
- temp = gr.Number(label="Temp (C)", value=37.0)
455
- avpu = gr.Dropdown(["A", "V", "P", "U"], label="AVPU", value="A")
456
-
457
- submit_btn = gr.Button("🚨 Assess Patient", variant="primary", size="lg")
458
-
459
- with gr.Row():
460
- output_text = gr.Textbox(label="AI Triage Assessment", lines=8)
461
- gr.Markdown("""
462
- ### ⚠️ Safety Disclaimer
463
- This system is a **research prototype**. It is **NOT** a certified medical device.
464
- """)
465
-
466
- lookup_btn.click(
467
- fn=lookup_patient_ui,
468
- inputs=[nhs_number],
469
- outputs=[age, gender, pmh, lookup_status]
470
- )
471
-
472
- submit_btn.click(
473
- fn=gradio_predict,
474
- inputs=[complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu],
475
- outputs=output_text
476
- )
477
-
478
- gr.Examples(
479
- examples=[
480
- ["Crushing chest pain and nausea", 72, "Male", "HTN, High Cholesterol", 110, "90/60", 94, 24, 37.2, "A"],
481
- ["Twisted ankle at football", 22, "Male", "None", 75, "125/85", 99, 14, 36.8, "A"],
482
- ],
483
- inputs=[complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu]
484
- )
485
-
486
- # Mount Gradio app to FastAPI at root
487
- # Secure the UI with the same credentials as the API
488
- app = gr.mount_gradio_app(app, demo, path="/", auth=get_gradio_auth())
489
-
490
- if __name__ == "__main__":
491
- print("Starting Hybrid Server on port 7860...")
492
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ NurseSim-Triage Hybrid Agent Entry Point
4
+
5
+ This module combines the A2A API (for AgentBeats) and the Gradio UI (for Human/Demo)
6
+ into a single FastAPI application listening on port 7860.
7
+ """
8
+
9
+ import os
10
+ import json
11
+ import secrets
12
+ import torch
13
+ import logging
14
+ import uvicorn
15
+ import asyncio
16
+ import secrets
17
+ import gradio as gr
18
+ from contextlib import asynccontextmanager
19
+ from fastapi import FastAPI, HTTPException, Request, Depends, Security, status
20
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
21
+ from fastapi.responses import JSONResponse
22
+ from fastapi.middleware.cors import CORSMiddleware
23
+ from typing import Dict, Any
24
+ from pydantic import BaseModel
25
+ from typing import Optional
26
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
27
+ from peft import PeftModel
28
+
29
+ # PDS Client for NHS patient lookup
30
+ from nursesim_rl.pds_client import PDSClient, PDSEnvironment, PatientDemographics, RestrictedPatientError
31
+
32
+ # ==========================================
33
+ # Data Models
34
+ # ==========================================
35
+
36
+ class Vitals(BaseModel):
37
+ heart_rate: int = 80
38
+ blood_pressure: str = "120/80"
39
+ spo2: int = 98
40
+ temperature: float = 37.0
41
+
42
+ class TaskInput(BaseModel):
43
+ complaint: str
44
+ vitals: Vitals
45
+ nhs_number: Optional[str] = None
46
+ age: Optional[int] = None
47
+ gender: Optional[str] = None
48
+ relevant_pmh: Optional[str] = None
49
+ rr: Optional[int] = 16
50
+ avpu: Optional[str] = "A"
51
+
52
+ # ==========================================
53
+ # Agent Core Logic
54
+ # ==========================================
55
+
56
+ class NurseSimTriageAgent:
57
+ """
58
+ Shared agent logic for both API and UI.
59
+ """
60
+
61
+ def __init__(self):
62
+ """Initialize the triage agent placeholder."""
63
+ self.model = None
64
+ self.tokenizer = None
65
+ self.HF_TOKEN = os.environ.get("HF_TOKEN")
66
+
67
+ # Initialize PDS client for NHS patient lookup (sandbox mode)
68
+ self.pds_client = PDSClient(environment=PDSEnvironment.SANDBOX)
69
+
70
+ if not self.HF_TOKEN:
71
+ print("WARNING: HF_TOKEN not set. Model loading will fail if authentication is required.")
72
+
73
+ async def load_model(self):
74
+ """Load the base model and LoRA adapters asynchronously."""
75
+ if self.model is not None:
76
+ return
77
+
78
+ try:
79
+ print("⏳ Starting model load...")
80
+ base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
81
+ adapter_id = "NurseCitizenDeveloper/NurseSim-Triage-Llama-3.2-3B"
82
+
83
+ # Offload heavy loading to thread
84
+ await asyncio.to_thread(self._load_weights, base_model_id, adapter_id)
85
+
86
+ print("✅ Model loaded successfully!")
87
+ except Exception as e:
88
+ print(f"❌ CRITICAL ERROR loading model: {e}")
89
+ import traceback
90
+ traceback.print_exc()
91
+
92
+ def _load_weights(self, base_model_id, adapter_id):
93
+ print(f"Loading tokenizer from {adapter_id}...")
94
+ self.tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=self.HF_TOKEN)
95
+
96
+ print(f"Loading base model {base_model_id} with 4-bit quantization...")
97
+ bnb_config = BitsAndBytesConfig(
98
+ load_in_4bit=True,
99
+ bnb_4bit_compute_dtype=torch.float16,
100
+ bnb_4bit_quant_type="nf4",
101
+ bnb_4bit_use_double_quant=True,
102
+ )
103
+
104
+ self.model = AutoModelForCausalLM.from_pretrained(
105
+ base_model_id,
106
+ quantization_config=bnb_config,
107
+ device_map="auto",
108
+ low_cpu_mem_usage=True,
109
+ token=self.HF_TOKEN,
110
+ )
111
+
112
+ print(f"Applying LoRA adapters from {adapter_id}...")
113
+ self.model = PeftModel.from_pretrained(self.model, adapter_id, token=self.HF_TOKEN)
114
+ self.model.eval()
115
+
116
+ def get_response(self, complaint: str, hr: int, bp: str, spo2: int, temp: float, rr: int = 16, avpu: str = "A", age: int = 45, gender: str = "Male", pmh: str = "None") -> str:
117
+ """Shared inference logic."""
118
+ if self.model is None:
119
+ return "⚠️ System is warming up. Please try again in 30 seconds."
120
+
121
+ # Construct History Dictionary (Critical for Model Accuracy)
122
+ history_dict = {
123
+ 'age': int(age) if age else "Unknown",
124
+ 'gender': gender,
125
+ 'relevant_PMH': pmh if pmh else "None",
126
+ 'time_course': "See complaint"
127
+ }
128
+
129
+ input_text = f"""PATIENT PRESENTING TO A&E TRIAGE
130
+
131
+ Chief Complaint: "{complaint}"
132
+
133
+ Vitals:
134
+ - HR: {hr} bpm
135
+ - BP: {bp} mmHg
136
+ - SpO2: {spo2}%
137
+ - RR: {rr} /min
138
+ - Temp: {temp}C
139
+ - AVPU: {avpu}
140
+
141
+ History: {history_dict}
142
+
143
+ WAITING ROOM: 12 patients | AVAILABLE BEDS: 4
144
+
145
+ What is your triage decision?"""
146
+
147
+ prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
148
+
149
+ ### Instruction:
150
+ You are an expert A&E Triage Nurse using the Manchester Triage System. Assess the following patient and provide your triage decision with clinical reasoning.
151
+
152
+ ### Input:
153
+ {input_text}
154
+
155
+ ### Response:
156
+ """
157
+
158
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
159
+
160
+ with torch.no_grad():
161
+ outputs = self.model.generate(
162
+ **inputs,
163
+ max_new_tokens=256,
164
+ do_sample=True,
165
+ temperature=0.6,
166
+ top_p=0.9,
167
+ pad_token_id=self.tokenizer.eos_token_id,
168
+ )
169
+
170
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
171
+ if "### Response:" in response:
172
+ try:
173
+ response = response.split("### Response:")[-1].strip()
174
+ except Exception:
175
+ pass
176
+
177
+ return response
178
+
179
+ def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
180
+ """Process an API task, optionally fetching patient demographics from PDS."""
181
+ if self.model is None:
182
+ return {
183
+ "error": "ModelStillLoading",
184
+ "message": "The agent is still warming up. Please retry in 30 seconds."
185
+ }
186
+
187
+ try:
188
+ complaint = task.get("complaint", "")
189
+ vitals = task.get("vitals", {})
190
+ nhs_number = task.get("nhs_number")
191
+
192
+ # If NHS number provided, enrich with PDS data
193
+ patient_info = None
194
+ if nhs_number:
195
+ try:
196
+ patient_info = self.lookup_patient(nhs_number)
197
+ except RestrictedPatientError as e:
198
+ print(f"SECURITY ALERT: {e}")
199
+ # Explicitly do NOT set patient_info so data is not leaked
200
+ except Exception as e:
201
+ print(f"PDS lookup failed: {e}")
202
+
203
+ response = self.get_response(
204
+ complaint,
205
+ vitals.get("heart_rate", 80),
206
+ vitals.get("blood_pressure", "120/80"),
207
+ vitals.get("spo2", 98),
208
+ vitals.get("temperature", 37.0)
209
+ )
210
+
211
+ result = {
212
+ "triage_category": self._extract_triage_category(response),
213
+ "assessment": response,
214
+ "recommended_action": self._extract_recommended_action(response)
215
+ }
216
+
217
+ # Include patient info if retrieved
218
+ if patient_info:
219
+ result["patient"] = {
220
+ "nhs_number": patient_info.nhs_number,
221
+ "name": patient_info.full_name,
222
+ "age": patient_info.age,
223
+ "gender": patient_info.gender,
224
+ "gp_practice": patient_info.gp_practice_name,
225
+ }
226
+
227
+ return result
228
+
229
+ except Exception as e:
230
+ logger.exception("Error processing task")
231
+ return {"error": "Internal Processing Error", "triage_category": "Error"}
232
+
233
+ def lookup_patient(self, nhs_number: str) -> PatientDemographics:
234
+ """
235
+ Look up patient demographics from NHS PDS.
236
+
237
+ Args:
238
+ nhs_number: 10-digit NHS number
239
+
240
+ Returns:
241
+ PatientDemographics object with patient details
242
+ """
243
+ return self.pds_client.lookup_patient_sync(nhs_number)
244
+
245
+ def _extract_triage_category(self, response: str) -> str:
246
+ response_lower = response.lower()
247
+ if "immediate" in response_lower or "resuscitation" in response_lower: return "Immediate"
248
+ elif "very urgent" in response_lower or "emergency" in response_lower: return "Very Urgent"
249
+ elif "urgent" in response_lower: return "Urgent"
250
+ elif "standard" in response_lower: return "Standard"
251
+ elif "non-urgent" in response_lower or "non urgent" in response_lower: return "Non-Urgent"
252
+ else: return "Standard"
253
+
254
+ def _extract_recommended_action(self, response: str) -> str:
255
+ if "monitor" in response.lower(): return "Monitor patient closely"
256
+ elif "immediate" in response.lower() or "urgent" in response.lower(): return "Immediate medical attention required"
257
+ else: return "Continue assessment and treatment as per protocol"
258
+
259
+ def health_check(self) -> Dict[str, Any]:
260
+ return {
261
+ "status": "healthy" if self.model is not None else "loading",
262
+ "model_loaded": self.model is not None,
263
+ "gpu_available": torch.cuda.is_available()
264
+ }
265
+
266
+ # ==========================================
267
+ # Application Setup
268
+ # ==========================================
269
+
270
+ # Configure logging
271
+ logging.basicConfig(
272
+ level=logging.INFO,
273
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
274
+ )
275
+ logger = logging.getLogger(__name__)
276
+
277
+ agent = NurseSimTriageAgent()
278
+
279
+ @asynccontextmanager
280
+ async def lifespan(app: FastAPI):
281
+ print("🚀 Server starting. Triggering model load task...")
282
+ asyncio.create_task(agent.load_model())
283
+ yield
284
+ print("🛑 Server shutting down.")
285
+
286
+ app = FastAPI(title="NurseSim-Triage Agent", version="1.2.0", lifespan=lifespan)
287
+
288
+ app.add_middleware(
289
+ CORSMiddleware,
290
+ allow_origins=["*"],
291
+ allow_credentials=True,
292
+ allow_methods=["*"],
293
+ allow_headers=["*"],
294
+ )
295
+
296
+ # ==========================================
297
+ # Security
298
+ # ==========================================
299
+
300
+ security = HTTPBearer()
301
+
302
+ async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
303
+ """
304
+ Verify API key or HF token from Authorization header.
305
+ Fail-closed: If no keys are configured, all access is denied.
306
+ """
307
+ api_key = os.environ.get("API_KEY")
308
+ hf_token = os.environ.get("HF_TOKEN")
309
+
310
+ if not api_key and not hf_token:
311
+ # System locked down if no keys configured
312
+ raise HTTPException(
313
+ status_code=status.HTTP_403_FORBIDDEN,
314
+ detail="System misconfigured: No authentication keys set."
315
+ )
316
+
317
+ token = credentials.credentials
318
+
319
+ # Check against available keys
320
+ if api_key and secrets.compare_digest(token, api_key):
321
+ return token
322
+ if hf_token and secrets.compare_digest(token, hf_token):
323
+ return token
324
+
325
+ raise HTTPException(
326
+ status_code=status.HTTP_401_UNAUTHORIZED,
327
+ detail="Invalid authentication credentials",
328
+ headers={"WWW-Authenticate": "Bearer"},
329
+ )
330
+
331
+ def get_gradio_auth():
332
+ """
333
+ Get authentication credentials for Gradio UI.
334
+ Mirroring the API security: supports both API_KEY and HF_TOKEN.
335
+ """
336
+ auth_creds = []
337
+ api_key = os.environ.get("API_KEY")
338
+ hf_token = os.environ.get("HF_TOKEN")
339
+
340
+ if api_key:
341
+ auth_creds.append(("admin", api_key))
342
+ if hf_token:
343
+ auth_creds.append(("admin", hf_token))
344
+
345
+ if not auth_creds:
346
+ random_key = secrets.token_urlsafe(16)
347
+ print(f"WARNING: No authentication keys set. Gradio UI locked with random key: {random_key}")
348
+ auth_creds.append(("admin", random_key))
349
+
350
+ return auth_creds
351
+
352
+ # ==========================================
353
+ # API Endpoints
354
+ # ==========================================
355
+
356
+ @app.get("/health")
357
+ async def health_check():
358
+ return agent.health_check()
359
+
360
+ @app.get("/.well-known/agent-card.json")
361
+ async def get_agent_card():
362
+ card_path = ".well-known/agent-card.json"
363
+ if os.path.exists(card_path):
364
+ with open(card_path, "r") as f:
365
+ return json.load(f)
366
+ raise HTTPException(status_code=404, detail="Agent card not found")
367
+
368
+ @app.post("/process-task", dependencies=[Depends(verify_api_key)])
369
+ async def process_task(task: TaskInput):
370
+ result = agent.process_task(task.dict())
371
+ if "error" in result and result.get("message") == "ModelStillLoading":
372
+ raise HTTPException(status_code=503, detail=result["message"])
373
+ return result
374
+
375
+ class PatientLookupRequest(BaseModel):
376
+ nhs_number: str
377
+
378
+ @app.post("/lookup-patient", dependencies=[Depends(verify_api_key)])
379
+ async def api_lookup_patient(request: PatientLookupRequest):
380
+ """Direct endpoint to lookup patient details from NHS PDS. Requires authentication."""
381
+ try:
382
+ patient = agent.lookup_patient(request.nhs_number)
383
+ return {
384
+ "nhs_number": patient.nhs_number,
385
+ "full_name": patient.full_name,
386
+ "date_of_birth": patient.date_of_birth,
387
+ "age": patient.age,
388
+ "gender": patient.gender,
389
+ "address": patient.address,
390
+ "gp_practice": patient.gp_practice_name
391
+ }
392
+ except RestrictedPatientError as e:
393
+ logger.warning(f"Access denied for restricted patient: {request.nhs_number}")
394
+ raise HTTPException(status_code=403, detail="🚫 ACCESS DENIED: Restricted Patient Record")
395
+ except ValueError as e:
396
+ raise HTTPException(status_code=400, detail=str(e))
397
+ except Exception as e:
398
+ logger.exception("Unexpected error during patient lookup")
399
+ raise HTTPException(status_code=500, detail="Internal Server Error")
400
+
401
+ # ==========================================
402
+ # Gradio UI Integration
403
+ # ==========================================
404
+
405
+ def lookup_patient_ui(nhs_no):
406
+ """Gradio handler for PDS lookup."""
407
+ if not nhs_no:
408
+ return 45, "Male", "", "Please enter an NHS Number."
409
+ try:
410
+ patient = agent.lookup_patient(nhs_no)
411
+ pmh_context = f"Registered GP: {patient.gp_practice_name}"
412
+ status_msg = f"✅ Verified: {patient.full_name}"
413
+ return patient.age, patient.gender, pmh_context, status_msg
414
+ except RestrictedPatientError:
415
+ return 45, "Male", "", "🚫 ACCESS DENIED: Restricted Record"
416
+ except Exception as e:
417
+ return 45, "Male", "", f"❌ Lookup failed: {str(e)}"
418
+
419
+ def gradio_predict(complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu):
420
+ return agent.get_response(complaint, hr, bp, spo2, temp, rr, avpu, age, gender, pmh)
421
+
422
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate")) as demo:
423
+ gr.Markdown("""
424
+ # 🏥 NurseSim AI: Emergency Triage Simulator
425
+ **An AI agent fine-tuned for the Manchester Triage System (MTS).**
426
+
427
+ > ⚡ **Hybrid Mode**: Serving both Gradio UI and A2A API (AgentBeats)
428
+ """)
429
+
430
+ with gr.Row():
431
+ with gr.Column(scale=1):
432
+ gr.Markdown("### 1. Patient Demographics")
433
+ with gr.Row():
434
+ nhs_number = gr.Textbox(label="NHS Number", placeholder="e.g. 9000000009", scale=2)
435
+ lookup_btn = gr.Button("🔍 Lookup", variant="secondary", scale=1)
436
+ lookup_status = gr.Markdown("")
437
+
438
+ age = gr.Number(label="Age", value=45)
439
+ gender = gr.Radio(["Male", "Female"], label="Gender", value="Male")
440
+ pmh = gr.Textbox(label="Medical History (PMH)", placeholder="e.g., Hypertension, Diabetes, Asthma", lines=2)
441
+
442
+ gr.Markdown("### 2. Presentation")
443
+ complaint = gr.Textbox(label="Chief Complaint", placeholder="e.g., Crushing chest pain radiating to jaw", lines=2)
444
+
445
+ with gr.Column(scale=1):
446
+ gr.Markdown("### 3. Vital Signs")
447
+ with gr.Row():
448
+ hr = gr.Number(label="HR (bpm)", value=80)
449
+ rr = gr.Number(label="RR (breaths/min)", value=16)
450
+ with gr.Row():
451
+ bp = gr.Textbox(label="BP (mmHg)", value="120/80")
452
+ spo2 = gr.Slider(label="SpO2 (%)", minimum=50, maximum=100, value=98)
453
+ with gr.Row():
454
+ temp = gr.Number(label="Temp (C)", value=37.0)
455
+ avpu = gr.Dropdown(["A", "V", "P", "U"], label="AVPU", value="A")
456
+
457
+ submit_btn = gr.Button("🚨 Assess Patient", variant="primary", size="lg")
458
+
459
+ with gr.Row():
460
+ output_text = gr.Textbox(label="AI Triage Assessment", lines=8)
461
+ gr.Markdown("""
462
+ ### ⚠️ Safety Disclaimer
463
+ This system is a **research prototype**. It is **NOT** a certified medical device.
464
+ """)
465
+
466
+ lookup_btn.click(
467
+ fn=lookup_patient_ui,
468
+ inputs=[nhs_number],
469
+ outputs=[age, gender, pmh, lookup_status]
470
+ )
471
+
472
+ submit_btn.click(
473
+ fn=gradio_predict,
474
+ inputs=[complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu],
475
+ outputs=output_text
476
+ )
477
+
478
+ gr.Examples(
479
+ examples=[
480
+ ["Crushing chest pain and nausea", 72, "Male", "HTN, High Cholesterol", 110, "90/60", 94, 24, 37.2, "A"],
481
+ ["Twisted ankle at football", 22, "Male", "None", 75, "125/85", 99, 14, 36.8, "A"],
482
+ ],
483
+ inputs=[complaint, age, gender, pmh, hr, bp, spo2, rr, temp, avpu]
484
+ )
485
+
486
+ # Mount Gradio app to FastAPI at root
487
+ # Secure the UI with the same credentials as the API
488
+ app = gr.mount_gradio_app(app, demo, path="/")
489
+
490
+ if __name__ == "__main__":
491
+ print("Starting Hybrid Server on port 7860...")
492
+ uvicorn.run(app, host="0.0.0.0", port=7860)