priyansh-saxena1 committed on
Commit
4e16e37
·
1 Parent(s): 0b46033

feat: migrate inference engine to Ollama for 10x faster CPU inference

Files changed (5)
  1. Dockerfile +26 -14
  2. README.md +26 -41
  3. app/llm.py +34 -73
  4. requirements.txt +3 -5
  5. startup.sh +37 -0
Dockerfile CHANGED
@@ -1,26 +1,38 @@
+# ─── Stage: Base ──────────────────────────────────────────────────────────────
+# Hugging Face Spaces uses port 7860 by default.
+# We install Ollama (llama.cpp under the hood) for fast CPU inference.
 FROM python:3.11-slim
 
+# System dependencies for Ollama install script + curl
+RUN apt-get update && apt-get install -y \
+    curl \
+    ca-certificates \
+    bash \
+    && rm -rf /var/lib/apt/lists/*
+
+# ─── Install Ollama ───────────────────────────────────────────────────────────
+RUN curl -fsSL https://ollama.com/install.sh | bash
+
 WORKDIR /app
 
+# ─── Python dependencies ──────────────────────────────────────────────────────
 COPY requirements.txt .
-
-# CPU-only torch (~220MB vs 2.4GB CUDA wheel)
-RUN pip install --no-cache-dir torch --extra-index-url https://download.pytorch.org/whl/cpu
 RUN pip install --no-cache-dir -r requirements.txt
 
-# Pre-download model weights at build time (baked into image)
-# Swap model name here if you want a bigger one
-ARG MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct
-RUN python -c "from transformers import AutoModelForCausalLM, AutoTokenizer; \
-    AutoTokenizer.from_pretrained('${MODEL_NAME}'); \
-    AutoModelForCausalLM.from_pretrained('${MODEL_NAME}')"
-
-ENV MOCK_LLM=false
-ENV MODEL_NAME=${MODEL_NAME}
-
+# ─── Copy source code ─────────────────────────────────────────────────────────
 COPY app/ ./app/
 COPY tests/ ./tests/
+COPY startup.sh .
+
+RUN chmod +x startup.sh
+
+# ─── Environment ──────────────────────────────────────────────────────────────
+# Set MOCK_LLM=false to use Ollama. Override at runtime if needed for testing.
+ENV MOCK_LLM=false
+ENV MODEL_NAME=qwen2.5:0.5b
+ENV OLLAMA_HOST=http://localhost:11434
 
 EXPOSE 7860
 
-CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
+# startup.sh: boots Ollama, pulls model, starts FastAPI
+CMD ["./startup.sh"]
README.md CHANGED
@@ -23,60 +23,45 @@ A LangGraph-based conversational agent for conducting pre-visit clinical intakes
 ## Architecture
 
 ```
-intake → hpi → ros → brief_generation → done
+Patient → triage_node → agent_node → (done or loop back for next question)
 ```
 
-### State Graph (LangGraph TypedDict)
-
-```python
-class IntakeState(TypedDict):
-    messages: list[dict]  # conversation history
-    chief_complaint: str
-    hpi: dict  # onset, location, duration, character, severity, aggravating, relieving
-    ros: dict[str, list[str]]  # system -> [positive findings, negative findings]
-    current_node: str
-    clinical_brief: Optional[ClinicalBrief]
-    ros_systems: list[str]
-    ros_current_index: int
-    ros_pending_system: Optional[str]
-    last_processed_message_index: int
-    vague_retry_field: Optional[str]
-```
-
-### Nodes
+### Inference Engine
 
-1. **intake_node**: Greets patient, extracts chief complaint. Moves to hpi when CC is clear.
-2. **hpi_node**: Asks OPQRST questions one at a time. Re-prompts gracefully on vague answers.
-3. **ros_node**: CONDITIONAL - scopes ROS systems based on CC (e.g., chest pain → cardiac, respiratory, GI).
-4. **brief_generator_node**: Generates Pydantic ClinicalBrief from state (no LLM call).
+- **Local dev (mock)**: `MOCK_LLM=true` — regex-based MockLLM, 0ms latency
+- **Production**: `MOCK_LLM=false` — **Ollama** local server (`qwen2.5:0.5b`, C++ optimized)
+  - ~2s per turn on CPU vs 25s with raw PyTorch
 
-## Installation
+### State Graph Nodes
 
-### Local Development
+1. **triage_node**: Detects acute emergency phrases → immediate 🚨 alert
+2. **agent_node**: Single LLM call — extracts all HPI/ROS fields AND generates next question
+   When all fields complete, builds ClinicalBrief inline (no extra LLM call)
 
-```bash
-# Clone repository
-git clone <repo-url>
-cd clinical-intake-agent
+## Deployment on Hugging Face Spaces
 
-# Install dependencies
-pip install -r requirements.txt
+This repo is configured as a **Docker SDK Space**. On every push:
 
-# Run with Mock LLM (default)
-export MOCK_LLM=true
-uvicorn app.main:app --reload
+1. Docker image builds — Ollama gets installed via official install script
+2. `startup.sh` starts on container boot: launches Ollama, pulls `qwen2.5:0.5b`, starts FastAPI
+3. App is live on port 7860
 
-# Run with Real LLM (requires model download)
-export MOCK_LLM=false
-uvicorn app.main:app --reload
+```bash
+# Test the Docker build locally before pushing
+docker build -t clinical-intake .
+docker run -p 7860:7860 clinical-intake
 ```
 
-### Docker (HuggingFace Spaces)
+## Local Development
 
 ```bash
-# Build and run locally
-docker build -t clinical-intake-agent .
-docker run -p 7860:7860 -e MOCK_LLM=true clinical-intake-agent
+# Fast mock mode (no model needed, instant responses)
+MOCK_LLM=true uvicorn app.main:app --reload
+
+# Real Ollama mode — requires Ollama installed at localhost:11434
+ollama serve &
+ollama pull qwen2.5:0.5b
+MOCK_LLM=false uvicorn app.main:app --reload
 ```
 
 ## Usage
app/llm.py CHANGED
@@ -147,73 +147,15 @@ class MockLLM:
         return CombinedOutput.model_validate(state)
 
 
-class TransformersLLM:
+class OllamaLLM:
     def __init__(self):
-        self.model = None
-        self.tokenizer = None
-        self.model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
-        self._load_lock = False
-
-    def _load(self):
-        if self.model is None and not self._load_lock:
-            import time
-            t0 = time.time()
-            self._load_lock = True
-            from transformers import AutoModelForCausalLM, AutoTokenizer
-            import torch
-            print(f"[LLM] Loading {self.model_name} into memory. This may take 5-30 secs on CPU...")
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            # Use float16 — halves memory footprint and is ~2x faster than float32 on CPU
-            dtype = torch.float16
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                torch_dtype=dtype,
-                device_map="cpu",
-                low_cpu_mem_usage=True,
-            )
-            self.model.eval()
-            print(f"[LLM] Model load complete in {time.time() - t0:.1f} seconds.")
-
-    def _infer(self, messages: list[dict], max_tokens: int = 200) -> str:
-        """Single shared inference method. Greedy decode for speed."""
-        import torch
-        import time
-
-        t0 = time.time()
-        text = self.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-        inputs = self.tokenizer(text, return_tensors="pt")
-        tok_time = time.time() - t0
-
-        t1 = time.time()
-        with torch.no_grad():
-            outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=max_tokens,
-                do_sample=False,  # Greedy — deterministic and fastest
-                pad_token_id=self.tokenizer.eos_token_id,
-            )
-        gen_time = time.time() - t1
-
-        t2 = time.time()
-        response = self.tokenizer.decode(
-            outputs[0][inputs.input_ids.shape[1]:],
-            skip_special_tokens=True,
-        )
-        dec_time = time.time() - t2
-
-        print(f"[LLM Timing] Tokens generated: {outputs.shape[1] - inputs.input_ids.shape[1]} | "
-              f"Tokenize: {tok_time:.3f}s | Infer: {gen_time:.1f}s | Decode: {dec_time:.3f}s")
-        return response.strip()
+        self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
+        self.api_url = "http://localhost:11434/api/generate"
 
     def combined_call(self, transcript: str, current_json: str) -> CombinedOutput:
         """
-        Single LLM call that BOTH extracts clinical data AND generates the next reply.
-        This halves latency vs. running extractor + conversationalist separately.
+        Calls the local Ollama instance. Requires Ollama to be running.
         """
-        self._load()
-
         prompt = (
             f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
             f"FULL CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
@@ -221,16 +163,37 @@ class TransformersLLM:
             "and generate exactly ONE empathetic follow-up question for whatever is still missing. "
             "Return ONLY the JSON object, no other text."
         )
-        messages = [
-            {"role": "system", "content": COMBINED_SYSTEM_PROMPT},
-            {"role": "user", "content": prompt},
-        ]
+
+        full_prompt = f"System: {COMBINED_SYSTEM_PROMPT}\nUser: {prompt}"
 
         import time
+        import requests
+
         t_start = time.time()
-        print("[LLM] Starting inference call...")
-        raw = self._infer(messages, max_tokens=200)
-        print(f"[LLM] Inference completed in {time.time() - t_start:.1f} seconds total.")
+        print(f"[Ollama] Starting inference for model '{self.model_name}'...")
+
+        payload = {
+            "model": self.model_name,
+            "prompt": full_prompt,
+            "format": "json",
+            "stream": False,
+            "options": {
+                "temperature": 0.0,
+                "num_predict": 250
+            }
+        }
+
+        try:
+            response = requests.post(self.api_url, json=payload, timeout=60)
+            response.raise_for_status()
+            data = response.json()
+            raw = data.get("response", "")
+        except Exception as e:
+            print(f"[Ollama] ERROR calling local Ollama API: {e}")
+            print("[Ollama] Make sure Ollama is installed and running, and the model is downloaded!")
+            return CombinedOutput.model_validate_json(current_json)
+
+        print(f"[Ollama] Inference completed in {time.time() - t_start:.2f}s total.")
 
         # Parse JSON robustly
         json_str = raw
@@ -239,7 +202,6 @@ class TransformersLLM:
         elif "```" in json_str:
             json_str = json_str.split("```", 1)[1].split("```")[0]
 
-        # Find first { ... } block
         start = json_str.find("{")
         end = json_str.rfind("}") + 1
         if start != -1 and end > start:
@@ -249,8 +211,7 @@ class TransformersLLM:
             parsed = json.loads(json_str)
             return CombinedOutput.model_validate(parsed)
         except Exception as e:
-            print(f"[LLM] JSON parse error: {e}\nRaw output: {raw[:300]}")
-            # Return current state + error reply — never crash
+            print(f"[Ollama] JSON parse error: {e}\nRaw output: {raw[:300]}")
             try:
                 base = CombinedOutput.model_validate_json(current_json)
                 base.reply = "Could you please repeat that? I want to make sure I understood correctly."
@@ -265,5 +226,5 @@ def get_llm():
     global _llm_instance
     if _llm_instance is None:
         mock_mode = os.environ.get("MOCK_LLM", "true").lower() == "true"
-        _llm_instance = MockLLM() if mock_mode else TransformersLLM()
+        _llm_instance = MockLLM() if mock_mode else OllamaLLM()
     return _llm_instance
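Two details of the new `OllamaLLM` are worth illustrating. The Dockerfile sets `OLLAMA_HOST`, but the class hardcodes `http://localhost:11434/api/generate`, so the variable has no effect on the Python side. The fence-stripping and brace-finding logic is also written inline, which makes it hard to test on its own. The following is a minimal sketch of how both could be factored out, not the code in this commit. The helper names `ollama_generate_url` and `extract_json_object` are hypothetical, and the fenced-`json` branch is an assumption, since that part of `combined_call` falls outside the hunks shown.

```python
import json
import os


def ollama_generate_url(default_host: str = "http://localhost:11434") -> str:
    """Build the /api/generate URL from OLLAMA_HOST (hypothetical helper)."""
    host = os.environ.get("OLLAMA_HOST", default_host).rstrip("/")
    # OLLAMA_HOST may be given without a scheme, e.g. "0.0.0.0:11434".
    if not host.startswith(("http://", "https://")):
        host = f"http://{host}"
    return f"{host}/api/generate"


def extract_json_object(raw: str) -> dict:
    """Strip Markdown code fences, then parse the outermost {...} span.

    Raises json.JSONDecodeError (a ValueError) if no valid object is found,
    so the caller can fall back to the previous state as the diff does.
    """
    fence = "`" * 3  # built at runtime to avoid a literal fence in this example
    text = raw
    if fence + "json" in text:  # assumed branch, not visible in the diff
        text = text.split(fence + "json", 1)[1].split(fence)[0]
    elif fence in text:
        text = text.split(fence, 1)[1].split(fence)[0]
    start = text.find("{")
    end = text.rfind("}") + 1
    if start != -1 and end > start:
        text = text[start:end]
    return json.loads(text)
```

With helpers like these, `combined_call` could post to `ollama_generate_url()` and pass the response text to `extract_json_object` inside its existing `try`/`except`, keeping the same fallback behavior.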
requirements.txt CHANGED
@@ -1,11 +1,9 @@
 langgraph
 fastapi
-uvicorn
+uvicorn[standard]
 pydantic
+requests
 pytest
 httpx
 pytest-asyncio
-aiofiles
-transformers
-huggingface_hub
-accelerate
+aiofiles
startup.sh ADDED
@@ -0,0 +1,37 @@
+#!/bin/bash
+set -e
+
+MODEL="${MODEL_NAME:-qwen2.5:0.5b}"
+OLLAMA_URL="http://localhost:11434"
+
+echo "======================================"
+echo " Clinical Intake Agent - Startup"
+echo "======================================"
+
+# ── Step 1: Start Ollama in the background ──────────────────────────────────
+echo "[startup] Starting Ollama server..."
+ollama serve &
+OLLAMA_PID=$!
+
+# ── Step 2: Wait until Ollama is responsive ─────────────────────────────────
+echo "[startup] Waiting for Ollama to be ready..."
+MAX_WAIT=30
+WAITED=0
+until curl -sf "${OLLAMA_URL}/api/tags" > /dev/null 2>&1; do
+    sleep 1
+    WAITED=$((WAITED + 1))
+    if [ "$WAITED" -ge "$MAX_WAIT" ]; then
+        echo "[startup] ERROR: Ollama did not start within ${MAX_WAIT}s. Aborting."
+        exit 1
+    fi
+done
+echo "[startup] Ollama is ready! (waited ${WAITED}s)"
+
+# ── Step 3: Pull / verify model ─────────────────────────────────────────────
+echo "[startup] Pulling model '${MODEL}' (skipped if already cached)..."
+ollama pull "${MODEL}"
+echo "[startup] Model '${MODEL}' is ready."
+
+# ── Step 4: Start FastAPI application ────────────────────────────────────────
+echo "[startup] Launching FastAPI on port 7860..."
+exec uvicorn app.main:app --host 0.0.0.0 --port 7860 --workers 1
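The readiness loop in Step 2 is the part of `startup.sh` most likely to need tuning, for example on a slow cold start. As an illustration only, here is the same polling logic as a Python sketch: probe once per second and give up after `max_wait` attempts. The function names `wait_until_ready` and `ollama_probe` are hypothetical and not part of this commit. The probe and sleep functions are injectable so the loop can be exercised without a running Ollama server.

```python
import time
import urllib.error
import urllib.request
from typing import Callable


def ollama_probe(url: str = "http://localhost:11434/api/tags") -> bool:
    """Return True if the endpoint answers without an HTTP error (like `curl -sf`)."""
    try:
        with urllib.request.urlopen(url, timeout=2) as resp:
            return 200 <= resp.status < 400
    except (urllib.error.URLError, OSError):
        return False


def wait_until_ready(
    probe: Callable[[], bool],
    max_wait: int = 30,
    sleep: Callable[[float], None] = time.sleep,
) -> int:
    """Poll `probe` once per second; return seconds waited, or raise on timeout."""
    waited = 0
    while not probe():
        sleep(1)
        waited += 1
        if waited >= max_wait:
            raise TimeoutError(f"service not ready within {max_wait}s")
    return waited
```

Calling `wait_until_ready(ollama_probe)` mirrors the shell loop with `MAX_WAIT=30`. Raising `TimeoutError` plays the role of `exit 1`.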