Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from pydantic import BaseModel | |
# System prompt used as the "system" message of the chat request.
# It describes the interview workflow (INTAKE, HPI, ROS, DONE), the rules the
# model must follow, and the JSON object it must return. The keys listed under
# OUTPUT FORMAT are the same names as the fields of CombinedOutput.
# NOTE(review): the "β" glyphs inside the prompt look like mis-decoded symbols
# (arrows, status marks, dashes) — confirm against the original file encoding
# before changing them, and keep them in sync with build_state_context().
SYSTEM_PROMPT = """You are a clinical intake assistant conducting a pre-visit patient interview.
YOUR WORKFLOW (follow this order):
1. INTAKE: Identify the patient's chief complaint (main reason for visit).
2. HPI (History of Present Illness): Collect these fields ONE AT A TIME, in order:
- onset: when the symptom started
- location: where in the body
- duration: how long it has lasted
- character: quality (sharp, dull, pressure, burning, etc.)
- severity: how bad on a scale of 1-10
- aggravating: what makes it worse
- relieving: what makes it better
3. ROS (Review of Systems): Screen 3 body systems RELEVANT to the chief complaint.
Examples of relevant systems:
- Leg/knee/joint pain β musculoskeletal, neurological, vascular
- Chest pain β cardiac, respiratory, gi
- Headache β neurological, ophthalmologic, ent
- Abdominal pain β gi, genitourinary, musculoskeletal
- Back pain β musculoskeletal, neurological, genitourinary
4. DONE: When all HPI fields AND 3 ROS systems are filled, set reply to "Your clinical summary is ready. Please wait for the doctor."
CRITICAL RULES:
- NEVER re-ask a field that is already filled (marked β in the status).
- Ask exactly ONE question per turn about the FIRST missing item.
- For HPI: accept any answer the patient gives, even vague ones like "moderate" or "not sure".
- For ROS: ALWAYS add the system to BOTH "ros" and "ros_asked" β even for negative answers.
- Positive finding: "cardiac": ["palpitations present"]
- Negative finding: "respiratory": ["no shortness of breath"]
- Denied: "gi": ["denied nausea and vomiting"]
A "no" is still a valid clinical finding. Never leave a ros system in ros_asked but absent from ros.
- Do NOT ask emotional/psychological questions β stick to physical symptoms.
- All string fields must be strings, not arrays.
- Output ONLY valid JSON, no extra text.
OUTPUT FORMAT:
{
"chief_complaint": "..." or null,
"onset": "..." or null,
"location": "..." or null,
"duration": "..." or null,
"character": "..." or null,
"severity": "..." or null,
"aggravating": "..." or null,
"relieving": "..." or null,
"ros": {"system_name": ["finding1", "finding2"], ...},
"ros_asked": ["system_name1", "system_name2"],
"emergency": false,
"reply": "Your single question"
}"""
# HPI fields in the order the interview collects them.
HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
# Number of ROS body systems that must be recorded before the interview is done.
ROS_REQUIRED = 3


def build_state_context(current_json: str) -> str:
    """Render the clinical state as a field-status checklist for the LLM prompt.

    The result lists the chief complaint, every HPI field and every recorded
    ROS system as filled or MISSING, then ends with a ``CURRENT PHASE`` line
    (INTAKE, HPI, ROS or DONE) telling the model what to ask next.

    Args:
        current_json: JSON text of the current clinical state. Anything that
            is not a JSON object (invalid text, ``null``, a list, ``None``)
            is treated as an empty state instead of raising.

    Returns:
        The checklist as a single newline-joined string.
    """
    # NOTE(review): the "β" glyphs in the emitted strings look like
    # mis-decoded status symbols/dashes — confirm against the original file.
    try:
        state = json.loads(current_json)
    except (TypeError, ValueError, RecursionError):
        state = {}
    if not isinstance(state, dict):
        # Valid JSON such as "null" or "[]" has no .get(); start from scratch.
        state = {}

    lines = ["FIELD STATUS:"]

    cc = state.get("chief_complaint")
    if cc:
        lines.append(f' β chief_complaint: "{cc}"')
    else:
        lines.append(" β chief_complaint: MISSING β ask what brings them in")

    for field in HPI_FIELDS:
        val = state.get(field)
        if val:
            lines.append(f' β {field}: "{val}"')
        else:
            lines.append(f" β {field}: MISSING")

    # An explicit null (or any non-container value) counts as "nothing yet".
    ros = state.get("ros")
    if not isinstance(ros, dict):
        ros = {}
    ros_asked = state.get("ros_asked")
    if not isinstance(ros_asked, list):
        ros_asked = []

    for sys_name, findings in ros.items():
        lines.append(f' β ros.{sys_name}: {findings}')

    ros_remaining = ROS_REQUIRED - len(ros)
    if ros_remaining > 0:
        lines.append(f" β ros: {ros_remaining} more system(s) needed")
        if ros_asked:
            asked = ", ".join(str(name) for name in ros_asked)
            lines.append(f" βΉοΈ Already asked about: {asked} β DO NOT ask about these again")
    else:
        lines.append(f" β ros: all {ROS_REQUIRED} systems collected")

    # Phase is decided strictly in workflow order: complaint, HPI, ROS, done.
    missing_hpi = [f for f in HPI_FIELDS if not state.get(f)]
    if not cc:
        phase = "INTAKE"
        lines.append(f"\nCURRENT PHASE: {phase}")
    elif missing_hpi:
        phase = "HPI"
        first_missing = missing_hpi[0]
        lines.append(f"\nCURRENT PHASE: {phase} β ask about '{first_missing}' next")
    elif ros_remaining > 0:
        phase = "ROS"
        lines.append(f"\nCURRENT PHASE: {phase} β ask about the next body system relevant to '{cc}'")
        lines.append(" β οΈ IMPORTANT: Store BOTH positive AND negative ROS findings in 'ros' dict.")
        lines.append(" β οΈ A patient saying 'no' means: ros[\"system\"] = [\"no [symptom]\"]")
    else:
        phase = "DONE"
        lines.append(f"\nCURRENT PHASE: {phase} β all data collected")

    return "\n".join(lines)
# Structured result of one LLM turn: the accumulated clinical state plus the
# next message to show the patient. Field names mirror the JSON keys listed
# under OUTPUT FORMAT in SYSTEM_PROMPT. Every field has a default, so an empty
# object validates.
class CombinedOutput(BaseModel):
    # Main reason for the visit; None until it has been identified.
    chief_complaint: str | None = None
    # HPI fields (same names as HPI_FIELDS); None means "not collected yet".
    onset: str | None = None  # when the symptom started
    location: str | None = None  # where in the body
    duration: str | None = None  # how long it has lasted
    character: str | None = None  # quality (sharp, dull, pressure, ...)
    severity: str | None = None  # how bad, kept as text
    aggravating: str | None = None  # what makes it worse
    relieving: str | None = None  # what makes it better
    # Review-of-systems findings, keyed by body-system name.
    ros: dict[str, list[str]] = {}
    # Names of the body systems that have already been asked about.
    ros_asked: list[str] = []
    # Emergency flag; False unless a turn sets it.
    emergency: bool = False
    # Next message to send to the patient.
    reply: str = ""
class MockLLM:
    """Deterministic stand-in for the real LLM, used for testing.

    It never calls a model. Depending on ``stage`` it stores the patient's
    latest message into the next empty slot of the state and sets ``reply``
    to a canned question for the slot after that.
    """

    # Body systems the mock walks through during the ROS stage, in order.
    _ROS_SYSTEMS = ("cardiac", "respiratory", "gi")
    # Messages that are greetings/acknowledgements rather than a complaint.
    _GREETINGS = frozenset({"hello", "hi", "hey", "ok", "okay", "start", "yes", "sure"})
    # Substrings that flag an emergency during the ROS stage.
    _EMERGENCY_KEYWORDS = ("crushing", "can't breathe", "dying")
    # Question wording for each HPI field.
    _HPI_LABELS = {
        "onset": "when it started",
        "location": "where you feel it",
        "duration": "how long it's lasted",
        "character": "what it feels like",
        "severity": "how severe it is (1-10)",
        "aggravating": "what makes it worse",
        "relieving": "what makes it better",
    }

    def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> "CombinedOutput":
        """Advance the mock interview by one turn.

        Args:
            transcript: Conversation text, one utterance per line; patient
                lines start with ``Patient:``.
            current_json: JSON text of the current clinical state. Anything
                that is not a JSON object is treated as an empty state.
            stage: ``"intake"``, ``"hpi"`` or ``"ros"``. Any other value
                leaves the state untouched.

        Returns:
            The updated state validated as a ``CombinedOutput``.
        """
        state = self._load_state(current_json)
        last_patient_msg = self._last_patient_message(transcript)

        if stage == "intake":
            self._apply_intake(state, last_patient_msg)
        elif stage == "hpi":
            self._apply_hpi(state, last_patient_msg)
        elif stage == "ros":
            self._apply_ros(state, last_patient_msg)

        return CombinedOutput.model_validate(state)

    @staticmethod
    def _load_state(current_json):
        """Parse the state JSON, falling back to ``{}`` for anything but an object."""
        try:
            state = json.loads(current_json)
        except (TypeError, ValueError, RecursionError):
            return {}
        return state if isinstance(state, dict) else {}

    @staticmethod
    def _last_patient_message(transcript):
        """Return the text of the most recent ``Patient:`` line, or ``""``."""
        prefix = "Patient:"
        for line in reversed((transcript or "").strip().split("\n")):
            if line.startswith(prefix):
                # Drop only the leading speaker tag; a later "Patient:" inside
                # the message itself is part of what the patient said.
                return line[len(prefix):].strip()
        return ""

    def _apply_intake(self, state, message):
        """Record the chief complaint if the message looks like one, then set the reply."""
        if message and not state.get("chief_complaint"):
            # Greetings and very short messages are not treated as a complaint.
            if message.lower() not in self._GREETINGS and len(message) > 4:
                state["chief_complaint"] = message

        if state.get("chief_complaint"):
            state["reply"] = f"When did the {state['chief_complaint']} start?"
        else:
            state["reply"] = "What brings you in today?"

    def _apply_hpi(self, state, message):
        """Store the message in the first empty HPI field and ask about the next one."""
        if message:
            target = next((f for f in HPI_FIELDS if not state.get(f)), None)
            if target is not None:
                state[target] = message

        pending = next((f for f in HPI_FIELDS if not state.get(f)), None)
        if pending is None:
            state["reply"] = "Thank you, let me ask about other symptoms."
        else:
            state["reply"] = f"Can you tell me {self._HPI_LABELS.get(pending, pending)}?"

    def _apply_ros(self, state, message):
        """Store the message under the next unrecorded ROS system and ask about the following one."""
        # An explicit null (or any non-container value) counts as "nothing yet".
        ros = state.get("ros")
        if not isinstance(ros, dict):
            ros = {}
        ros_asked = state.get("ros_asked")
        if not isinstance(ros_asked, list):
            ros_asked = []

        lowered = message.lower()
        if any(keyword in lowered for keyword in self._EMERGENCY_KEYWORDS):
            state["emergency"] = True

        if message:
            target = next((s for s in self._ROS_SYSTEMS if s not in ros), None)
            if target is not None:
                # Findings and the asked-list are updated together so a system
                # never appears in ros_asked without an entry in ros.
                ros[target] = [message]
                if target not in ros_asked:
                    ros_asked.append(target)
                state["ros"] = ros
                state["ros_asked"] = ros_asked

        pending = next((s for s in self._ROS_SYSTEMS if s not in ros), None)
        if pending is None:
            # NOTE(review): the "β" below looks like a mis-decoded dash — confirm.
            state["reply"] = "Thank you β I have everything I need."
        else:
            state["reply"] = f"Any {pending} symptoms?"
class OllamaLLM:
    """LLM client that runs one interview turn against an Ollama chat endpoint.

    The model name comes from the ``MODEL_NAME`` environment variable
    (default ``"qwen2.5:0.5b"``); the endpoint is the fixed local URL set in
    ``__init__``.
    """

    # Fields the model must return as plain strings (chief complaint + HPI).
    _STRING_FIELDS = (
        "chief_complaint", "onset", "location", "duration",
        "character", "severity", "aggravating", "relieving",
    )

    def __init__(self):
        self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
        self.api_url = "http://localhost:11434/api/chat"

    def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> "CombinedOutput":
        """Send the state and transcript to the model and return the updated state.

        Args:
            transcript: Conversation text, one utterance per line; patient
                lines start with ``Patient:``.
            current_json: JSON text of the current clinical state.
            stage: Accepted for interface parity with ``MockLLM``; not used
                here.

        Returns:
            The model's output validated as a ``CombinedOutput``. If the
            request fails, the previous state is returned unchanged. If the
            reply cannot be parsed or validated, the previous state is
            returned with a "please repeat" reply. If the previous state
            itself is invalid, a bare "please repeat" result is returned.
        """
        # Imported lazily so the module can be loaded without `requests`.
        import time
        import requests

        state_context = build_state_context(current_json)
        prompt = (
            f"{state_context}\n\n"
            f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
            f"CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
            "TASK: Read the patient's latest message. Extract any new clinical facts into the JSON. "
            "Then ask exactly ONE question about the FIRST missing item shown above. "
            "For ROS: if the patient answers about a system (even 'no'), add it to BOTH ros AND ros_asked. "
            "Return ONLY the updated JSON object."
        )

        t_start = time.time()
        print(f"[Ollama] Starting inference for model '{self.model_name}'...")
        print(f"[Ollama] State context:\n{state_context}")

        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "format": "json",
            "stream": False,
            "options": {
                "temperature": 0.0,
                "num_predict": 400,
            },
        }

        try:
            response = requests.post(self.api_url, json=payload, timeout=60)
            response.raise_for_status()
            data = response.json()
            raw = data.get("message", {}).get("content", "").strip()
        except Exception as e:
            # Boundary with an external service: log and keep the last state.
            print(f"[Ollama] ERROR calling Ollama API: {e}")
            return self._restore_state(current_json)

        print(f"[Ollama] Inference completed in {time.time() - t_start:.2f}s total.")

        json_str = self._extract_json_text(raw)

        try:
            parsed = json.loads(json_str)
            self._coerce_string_fields(parsed)
            result = CombinedOutput.model_validate(parsed)
        except Exception as e:
            print(f"[Ollama] JSON parse error: {e}\nRaw output: {raw[:300]}")
            return self._restore_state(
                current_json,
                reply="Could you please repeat that? I want to make sure I understood correctly.",
            )

        # If the model listed a system in ros_asked without a matching ros
        # entry (typical for "no" answers), use the patient's last message as
        # the finding so the asked-list and findings stay in step.
        if result.ros_asked:
            last_user = self._last_patient_message(transcript)
            updated_ros = self._backfill_ros(result.ros, result.ros_asked, last_user)
            if updated_ros != result.ros:
                result = result.model_copy(update={"ros": updated_ros})

        # NOTE(review): the "β" below looks like a mis-decoded dash — confirm.
        print(f"[Ollama] Parsed result β stage will be recomputed in graph.")
        return result

    @staticmethod
    def _restore_state(current_json, reply=None):
        """Rebuild the output from the last known state, optionally replacing ``reply``.

        Falls back to a bare "please repeat" result when the stored state
        itself cannot be validated.
        """
        try:
            result = CombinedOutput.model_validate_json(current_json)
        except Exception:
            return CombinedOutput(reply="Could you please repeat that?")
        if reply is not None:
            result = result.model_copy(update={"reply": reply})
        return result

    @staticmethod
    def _extract_json_text(raw):
        """Return the ``{...}`` span of a model reply, ignoring markdown fences.

        The fenced section is tried first; if it holds no object the whole
        reply is searched. When no braces are found at all, the first
        candidate is returned unchanged so the caller's JSON parse reports
        the failure.
        """
        candidates = []
        if "```json" in raw:
            candidates.append(raw.split("```json", 1)[1].split("```", 1)[0])
        elif "```" in raw:
            candidates.append(raw.split("```", 1)[1].split("```", 1)[0])
        candidates.append(raw)

        for text in candidates:
            start = text.find("{")
            end = text.rfind("}") + 1
            if start != -1 and end > start:
                return text[start:end]
        return candidates[0]

    @staticmethod
    def _coerce_string_fields(parsed):
        """Normalise the string fields of a parsed reply in place.

        Lists are joined with spaces (``["Walking"]`` becomes ``"Walking"``),
        numbers become text (a severity of ``7`` becomes ``"7"``), and blank
        or ``"null"`` strings become ``None``. Absent keys are left absent.
        """
        for field in OllamaLLM._STRING_FIELDS:
            if field not in parsed:
                continue
            value = parsed[field]
            if isinstance(value, list):
                value = " ".join(str(x) for x in value) if value else None
            elif isinstance(value, (int, float)) and not isinstance(value, bool):
                # A numeric rating is a legitimate answer; keep it as text
                # rather than letting validation reject the whole turn.
                value = str(value)
            if isinstance(value, str) and value.strip() in ("", "null"):
                value = None
            parsed[field] = value

    @staticmethod
    def _last_patient_message(transcript):
        """Return the text of the most recent ``Patient:`` line, or ``""``."""
        prefix = "Patient:"
        for line in reversed((transcript or "").strip().split("\n")):
            if line.startswith(prefix):
                # Strip only the leading speaker tag, not later occurrences.
                return line[len(prefix):].strip()
        return ""

    @staticmethod
    def _backfill_ros(ros, ros_asked, last_user):
        """Return a copy of ``ros`` with an entry for every system in ``ros_asked``.

        Missing systems get the patient's last message as their finding, or
        ``"no symptoms reported"`` when there is no message.
        """
        updated = dict(ros)
        for asked_sys in ros_asked:
            if asked_sys not in updated:
                updated[asked_sys] = [last_user] if last_user else ["no symptoms reported"]
                print(f"[ROSNorm] Filled ros['{asked_sys}'] from patient message: '{last_user[:40]}'")
        return updated
# Cached client shared by every caller; created on the first get_llm() call.
_llm_instance = None


def get_llm():
    """Return the shared LLM client, creating it on first use.

    The ``MOCK_LLM`` environment variable (default ``"true"``) selects the
    implementation: ``MockLLM`` when it equals ``"true"`` ignoring case,
    ``OllamaLLM`` otherwise. The choice is made once and then reused.
    """
    global _llm_instance
    if _llm_instance is not None:
        return _llm_instance

    use_mock = os.environ.get("MOCK_LLM", "true").lower() == "true"
    if use_mock:
        _llm_instance = MockLLM()
    else:
        _llm_instance = OllamaLLM()
    return _llm_instance