GitHub Actions commited on
Commit
442387a
Β·
1 Parent(s): 9f0380a

πŸš€ Auto-sync from GitHub: 56b626d

Browse files
Files changed (2) hide show
  1. agent_main.py +75 -50
  2. run.sh +8 -4
agent_main.py CHANGED
@@ -10,12 +10,14 @@ import os
10
  import json
11
  import torch
12
  import uvicorn
 
 
13
  from fastapi import FastAPI, HTTPException, Request
14
- from fastapi.responses import JSONResponse, FileResponse
15
  from fastapi.middleware.cors import CORSMiddleware
16
- from typing import Dict, Any, Optional
17
  from pydantic import BaseModel
18
- from transformers import AutoModelForCausalLM, AutoTokenizer
19
  from peft import PeftModel
20
 
21
  # ==========================================
@@ -42,57 +44,73 @@ class NurseSimTriageAgent:
42
  """
43
 
44
  def __init__(self):
45
- """Initialize the triage agent and load the model."""
46
  self.model = None
47
  self.tokenizer = None
48
  self.HF_TOKEN = os.environ.get("HF_TOKEN")
49
 
50
  if not self.HF_TOKEN:
51
  print("WARNING: HF_TOKEN not set. Model loading will fail if authentication is required.")
52
-
53
- self._load_model()
54
 
55
- def _load_model(self):
56
- """Load the base model and LoRA adapters."""
57
  if self.model is not None:
58
- return # Already loaded
59
-
60
  try:
 
61
  base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
62
  adapter_id = "NurseCitizenDeveloper/NurseSim-Triage-Llama-3.2-3B"
63
 
64
- print(f"Loading tokenizer from {adapter_id}...")
65
- self.tokenizer = AutoTokenizer.from_pretrained(
66
- adapter_id,
67
- token=self.HF_TOKEN
68
- )
69
-
70
- print(f"Loading base model {base_model_id}...")
71
- # Use device_map="auto" to handle CPU/GPU automatically
72
- self.model = AutoModelForCausalLM.from_pretrained(
73
- base_model_id,
74
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
75
- device_map="auto",
76
- low_cpu_mem_usage=True,
77
- token=self.HF_TOKEN,
78
- )
79
 
80
- print(f"Applying LoRA adapters from {adapter_id}...")
81
- self.model = PeftModel.from_pretrained(
82
- self.model,
83
- adapter_id,
84
- token=self.HF_TOKEN
85
- )
86
- self.model.eval()
87
- print(f"Model loaded successfully on {self.model.device}!")
88
  except Exception as e:
89
- print(f"CRITICAL ERROR loading model: {e}")
90
- # We don't raise here to allow the server to start (and report unhealthy status)
91
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
93
  """Process an A2A task and return the triage assessment."""
94
  if self.model is None:
95
- return {"error": "Model not loaded", "triage_category": "Error"}
 
 
 
96
 
97
  try:
98
  # Extract task data
@@ -163,23 +181,30 @@ Vitals: HR {hr}, BP {bp}, SpO2 {spo2}%, Temp {temp}C.
163
  def health_check(self) -> Dict[str, Any]:
164
  """Return agent health status."""
165
  return {
166
- "status": "healthy" if self.model is not None else "unhealthy",
167
  "model_loaded": self.model is not None,
168
- "gpu_available": torch.cuda.is_available(),
169
- "device": str(self.model.device) if self.model else "not loaded"
170
  }
171
 
172
  # ==========================================
173
- # FastAPI Server Setup
174
  # ==========================================
175
 
176
- print("Initializing NurseSim-Triage Agent...")
177
  agent = NurseSimTriageAgent()
178
 
 
 
 
 
 
 
 
 
 
179
  app = FastAPI(
180
  title="NurseSim-Triage Agent",
181
- description="A2A Interface for Clinical Triage",
182
- version="1.0.0"
183
  )
184
 
185
  app.add_middleware(
@@ -192,7 +217,10 @@ app.add_middleware(
192
 
193
  @app.get("/")
194
  async def root():
195
- return {"message": "NurseSim-Triage Agent is running. Visit /health for status."}
 
 
 
196
 
197
  @app.get("/health")
198
  async def health_check():
@@ -208,11 +236,9 @@ async def get_agent_card():
208
 
209
  @app.post("/process-task")
210
  async def process_task(task: TaskInput):
211
- """
212
- Standard A2A task processing endpoint.
213
- Accepts JSON body matching TaskInput schema.
214
- """
215
  result = agent.process_task(task.dict())
 
 
216
  return result
217
 
218
  # ==========================================
@@ -221,5 +247,4 @@ async def process_task(task: TaskInput):
221
 
222
  if __name__ == "__main__":
223
  print("Starting A2A Server on port 8080...")
224
- # Listen on all interfaces (0.0.0.0) for Docker support
225
  uvicorn.run(app, host="0.0.0.0", port=8080)
 
10
  import json
11
  import torch
12
  import uvicorn
13
+ import asyncio
14
+ from contextlib import asynccontextmanager
15
  from fastapi import FastAPI, HTTPException, Request
16
+ from fastapi.responses import JSONResponse
17
  from fastapi.middleware.cors import CORSMiddleware
18
+ from typing import Dict, Any
19
  from pydantic import BaseModel
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
21
  from peft import PeftModel
22
 
23
  # ==========================================
 
44
  """
45
 
46
  def __init__(self):
47
+ """Initialize the triage agent placeholder."""
48
  self.model = None
49
  self.tokenizer = None
50
  self.HF_TOKEN = os.environ.get("HF_TOKEN")
51
 
52
  if not self.HF_TOKEN:
53
  print("WARNING: HF_TOKEN not set. Model loading will fail if authentication is required.")
 
 
54
 
55
+ async def load_model(self):
56
+ """Load the base model and LoRA adapters asynchronously."""
57
  if self.model is not None:
58
+ return
59
+
60
  try:
61
+ print("⏳ Starting model load...")
62
  base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
63
  adapter_id = "NurseCitizenDeveloper/NurseSim-Triage-Llama-3.2-3B"
64
 
65
+ # Offload heavy loading to thread to avoid blocking event loop
66
+ await asyncio.to_thread(self._load_weights, base_model_id, adapter_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ print("βœ… Model loaded successfully!")
 
 
 
 
 
 
 
69
  except Exception as e:
70
+ print(f"❌ CRITICAL ERROR loading model: {e}")
71
+ import traceback
72
+ traceback.print_exc()
73
+
74
+ def _load_weights(self, base_model_id, adapter_id):
75
+ print(f"Loading tokenizer from {adapter_id}...")
76
+ self.tokenizer = AutoTokenizer.from_pretrained(
77
+ adapter_id,
78
+ token=self.HF_TOKEN
79
+ )
80
+
81
+ print(f"Loading base model {base_model_id} with 4-bit quantization...")
82
+
83
+ # Modern 4-bit loading configuration
84
+ bnb_config = BitsAndBytesConfig(
85
+ load_in_4bit=True,
86
+ bnb_4bit_compute_dtype=torch.float16,
87
+ bnb_4bit_quant_type="nf4",
88
+ bnb_4bit_use_double_quant=True,
89
+ )
90
+
91
+ self.model = AutoModelForCausalLM.from_pretrained(
92
+ base_model_id,
93
+ quantization_config=bnb_config,
94
+ device_map="auto",
95
+ low_cpu_mem_usage=True,
96
+ token=self.HF_TOKEN,
97
+ )
98
+
99
+ print(f"Applying LoRA adapters from {adapter_id}...")
100
+ self.model = PeftModel.from_pretrained(
101
+ self.model,
102
+ adapter_id,
103
+ token=self.HF_TOKEN
104
+ )
105
+ self.model.eval()
106
+
107
  def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
108
  """Process an A2A task and return the triage assessment."""
109
  if self.model is None:
110
+ return {
111
+ "error": "ModelStillLoading",
112
+ "message": "The agent is still warming up. Please retry in 30 seconds."
113
+ }
114
 
115
  try:
116
  # Extract task data
 
181
  def health_check(self) -> Dict[str, Any]:
182
  """Return agent health status."""
183
  return {
184
+ "status": "healthy" if self.model is not None else "loading",
185
  "model_loaded": self.model is not None,
186
+ "gpu_available": torch.cuda.is_available()
 
187
  }
188
 
189
  # ==========================================
190
+ # FastAPI Lifecycle & App
191
  # ==========================================
192
 
 
193
  agent = NurseSimTriageAgent()
194
 
195
+ @asynccontextmanager
196
+ async def lifespan(app: FastAPI):
197
+ # Startup: Load model in background
198
+ print("πŸš€ Server starting. Triggering model load task...")
199
+ asyncio.create_task(agent.load_model())
200
+ yield
201
+ # Shutdown logic (if any)
202
+ print("πŸ›‘ Server shutting down.")
203
+
204
  app = FastAPI(
205
  title="NurseSim-Triage Agent",
206
+ version="1.1.0",
207
+ lifespan=lifespan
208
  )
209
 
210
  app.add_middleware(
 
217
 
218
  @app.get("/")
219
  async def root():
220
+ return {
221
+ "message": "NurseSim-Triage Agent Online",
222
+ "status": agent.health_check()["status"]
223
+ }
224
 
225
  @app.get("/health")
226
  async def health_check():
 
236
 
237
  @app.post("/process-task")
238
  async def process_task(task: TaskInput):
 
 
 
 
239
  result = agent.process_task(task.dict())
240
+ if "error" in result and result.get("message") == "ModelStillLoading":
241
+ raise HTTPException(status_code=503, detail=result["message"])
242
  return result
243
 
244
  # ==========================================
 
247
 
248
  if __name__ == "__main__":
249
  print("Starting A2A Server on port 8080...")
 
250
  uvicorn.run(app, host="0.0.0.0", port=8080)
run.sh CHANGED
@@ -1,11 +1,15 @@
1
  #!/bin/bash
2
  # Launcher script for NurseSim-Triage agent
3
- # Supports dual-mode deployment: Gradio (human UI) or A2A (platform integration)
4
-
5
  set -e
6
 
7
- # Fix for libgomp Runtime Error on Hugging Face Spaces (CPU Upgrade/Basic)
8
- export OMP_NUM_THREADS=1
 
 
 
 
 
 
9
 
10
 
11
  AGENT_MODE=${AGENT_MODE:-a2a}
 
1
  #!/bin/bash
2
  # Launcher script for NurseSim-Triage agent
 
 
3
  set -e
4
 
5
+ # Fix for libgomp Runtime Error on Hugging Face Spaces (CPU Upgrade only)
6
+ # If NO GPU (CUDA devices empty), restrict threads to avoid crash
7
+ if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
8
+ echo "Running on CPU - Setting OMP_NUM_THREADS=1"
9
+ export OMP_NUM_THREADS=1
10
+ else
11
+ echo "Running on GPU (detected) - Allowing automatic threading"
12
+ fi
13
 
14
 
15
  AGENT_MODE=${AGENT_MODE:-a2a}