| """ |
| src/graphs/graph_factory.py |
| |
| Factory function that builds a LangGraph StateGraph for any medical specialty. |
| Architecture: __start__ → information_retrieval → interpretation_generation → __end__ |
| """ |
| from __future__ import annotations |
|
|
| import json |
| from typing import Any, Dict, List |
|
|
| from langchain_core.messages import AIMessage, HumanMessage |
| from langgraph.graph import END, StateGraph |
|
|
| from src.llms.groqllm import GroqLLM |
| from src.states.expert_subgraph_state import ExpertSubgraphState |
| from src.utils.prompts import get_expert_prompt |
| from src.utils.utils import EXPERT_TOOLS, get_today_str |
| from src.utils.vector_store import get_history, save_interpretation |
| from src.utils.db import insert_interpretation |
|
|
|
|
def _augment_with_ml_models(specialty: str, telemetry: Dict[str, Any]) -> Dict[str, Any]:
    """
    Attach structured-model outputs (ECGFounder, respiratory CNN, fetal
    health, retinal classifiers, …) to the tool_results dict when the
    relevant adapter has weights loaded and the telemetry snapshot
    carries the inputs the adapter needs.

    Args:
        specialty: Expert specialty name. Adapters are selected by substring
            match ("Cardiology", "Pulmonary"/"Respiratory", "Neurology",
            "Dermatology", "Obstetrics"/"Gynecology", "Ocul"). None is
            treated as "" (no adapters).
        telemetry: Sensor snapshot, may be None/empty. The "waveform",
            "imaging", "patient" and "fetal" sub-dicts are read.

    Returns:
        A dict of extra {tool_name: json_string} entries — empty when no
        adapters are ready or no relevant inputs are present.

    Each adapter module under ``src.ml`` is imported lazily and only used
    when (a) the specialty matches, (b) ``adapter.is_loaded`` is true and
    (c) where the branch has one, its modality gate passes (waveform /
    fetal block / retinal image or demographics). Every adapter call is
    isolated: an import error, a broken weight file or an unserialisable
    result is logged at DEBUG level and that adapter is skipped, so one
    failure can never break the whole graph run.
    """
    import importlib
    import logging

    import numpy as np

    log = logging.getLogger(__name__)
    extras: Dict[str, str] = {}
    specialty = specialty or ""
    telemetry = telemetry or {}
    waveform = telemetry.get("waveform") or {}
    imaging = telemetry.get("imaging") or {}
    demographics = telemetry.get("patient") or {}

    def has_samples(value: Any) -> bool:
        # True when a signal has at least one sample. np.size works for lists
        # and numpy arrays alike; plain truthiness raises on arrays.
        return value is not None and np.size(value) > 0

    def run(key: str, module_name: str, factory_name: str, predict) -> None:
        # Build src.ml.<module_name>.<factory_name>() and store
        # json(predict(adapter)) under `key` when loaded and non-None.
        try:
            module = importlib.import_module(f"src.ml.{module_name}")
            adapter = getattr(module, factory_name)()
            if not adapter.is_loaded:
                return
            pred = predict(adapter)
            if pred is not None:
                extras[key] = json.dumps(pred)
        except Exception:
            # Best-effort by design: record why, then move on to the next adapter.
            log.debug("ML adapter %r skipped for %r", key, specialty, exc_info=True)

    if "Cardiology" in specialty and waveform:

        def ecgfounder(adapter):
            # Classify lead II only when samples are actually present.
            lead2 = waveform.get("ecg_lead2")
            if not has_samples(lead2):
                return None
            return adapter.classify(
                np.asarray(lead2, dtype=float), fs=waveform.get("fs", 40)
            )

        run("ecgfounder_classification", "ecgfounder_adapter", "get_ecgfounder", ecgfounder)
        run(
            "ecg_arrhythmia_prediction", "ecg_arrhythmia_adapter", "get_ecg_arrhythmia",
            lambda adapter: adapter.predict_dict(telemetry),
        )
        run(
            "cardiac_age_prediction", "cardiac_age_adapter", "get_cardiac_age",
            lambda adapter: adapter.predict_with_chrono(telemetry),
        )

    if ("Pulmonary" in specialty or "Respiratory" in specialty) and waveform:

        def respiratory_cnn(clf):
            # Run the CNN on the raw audio trace and flatten its result object.
            audio = waveform.get("audio")
            if not has_samples(audio):
                return None
            pred = clf.predict(np.asarray(audio, dtype=float), fs=waveform.get("fs", 40))
            if pred is None:
                return None
            return {"label": pred.label, "probs": pred.probs, "confidence": pred.confidence}

        run(
            "respiratory_cnn_classification", "pulmonary_classifier",
            "get_pulmonary_classifier", respiratory_cnn,
        )
        run(
            "lung_sound_prediction", "lung_sound_adapter", "get_lung_sound",
            lambda adapter: adapter.predict_dict(telemetry),
        )

    if "Neurology" in specialty:
        run(
            "parkinson_screener_prediction", "parkinson_screener_adapter",
            "get_parkinson_screener", lambda adapter: adapter.predict_dict(telemetry),
        )

    if "Dermatology" in specialty:
        skin_path = (imaging.get("skin") or {}).get("image_path")
        run(
            "skin_disease_prediction", "skin_disease_adapter", "get_skin_disease",
            lambda adapter: adapter.predict_with_image(demographics, image_path=skin_path),
        )

    if "Obstetrics" in specialty or "Gynecology" in specialty:
        fetal_block = telemetry.get("fetal") or {}
        if fetal_block:
            run(
                "fetal_health_prediction", "fetal_health_adapter", "get_fetal_health",
                lambda adapter: adapter.predict_dict(fetal_block),
            )
            run(
                "preterm_labour_prediction", "preterm_labour_adapter", "get_preterm_labour",
                lambda adapter: adapter.predict_dict(
                    {"fetal": fetal_block, "patient": demographics}
                ),
            )

    # "Ocul" already matches "Ocular"; a separate "Ocular" check would be redundant.
    if "Ocul" in specialty:
        retinal_path = (imaging.get("retinal") or {}).get("image_path")
        if demographics or retinal_path:
            run(
                "retinal_disease_prediction", "retinal_disease_adapter", "get_retinal_disease",
                lambda adapter: adapter.predict_with_image(demographics, image_path=retinal_path),
            )
            run(
                "retinal_age_prediction", "retinal_age_adapter", "get_retinal_age",
                lambda adapter: adapter.predict_with_image(demographics, image_path=retinal_path),
            )

    return extras
|
|
def _get_model(specialty: str | None = None):
    """
    Return the chat model used for a specialty's interpretation step.

    A fresh ``GroqLLM`` wrapper is constructed on every call and its
    ``get_llm()`` result is returned. Callers in this module invoke that
    object with ``.invoke([...])`` and read ``.content`` from the reply.

    Args:
        specialty: Specialty name forwarded to ``GroqLLM``. ``None`` is
            passed through unchanged — presumably GroqLLM then selects a
            default model (TODO confirm against GroqLLM).
    """
    return GroqLLM(specialty).get_llm()
|
|
|
|
| |
|
|
|
|
def _make_information_retrieval(specialty: str):
    """
    Create the information_retrieval node for a given specialty.

    Behaviour:
      • If ``state["sensor_telemetry"]`` carries any of ``vitals``,
        ``waveform``, ``imaging`` or ``fetal`` data (mobile attached a real
        snapshot — the common case in production), the mock retrieval tools
        are skipped. The LLM still sees the real values via the
        ``telemetry_context`` block of the system prompt; running the mock
        tools alongside would inject fabricated values that contradict the
        live data and confuse the assessment. A ``live_telemetry`` note is
        recorded instead.
      • If no snapshot is present (e.g. background_agent_runner_loop on a
        quiet system), every tool in ``EXPERT_TOOLS[specialty]`` is invoked
        with ``{}`` so the LLM has *something* to interpret. A tool that
        raises is recorded as ``"ERROR: <exception>"``.
      • ML-model adapter outputs (``_augment_with_ml_models``) are appended
        in either case. If augmentation itself fails, a warning is logged
        and the node still returns the other results.

    Args:
        specialty: Key into ``EXPERT_TOOLS``; also forwarded to the ML
            augmentation and echoed in messages and traces.

    Returns:
        The node callable ``(state) -> dict``. The dict holds
        ``tool_results``, one status ``AIMessage`` under ``messages``, and
        one ``traces`` entry.
    """
    import logging

    log = logging.getLogger(__name__)
    tools = EXPERT_TOOLS.get(specialty, [])
    telemetry_sections = ("vitals", "waveform", "imaging", "fetal")

    def information_retrieval(state: ExpertSubgraphState) -> Dict[str, Any]:
        tool_results: Dict[str, str] = {}
        telemetry = state.get("sensor_telemetry") or {}
        has_telemetry = any(bool(telemetry.get(section)) for section in telemetry_sections)

        if has_telemetry:
            tool_results["live_telemetry"] = (
                "Live telemetry attached to this request — see "
                "## LIVE TELEMETRY CONTEXT in the prompt below. Trust "
                "those numeric values over any other defaults."
            )
        else:
            for tool_fn in tools:
                try:
                    tool_results[tool_fn.name] = tool_fn.invoke({})
                except Exception as e:
                    tool_results[tool_fn.name] = f"ERROR: {e}"

        try:
            tool_results.update(_augment_with_ml_models(specialty, telemetry))
        except Exception:
            # The augmenter isolates each adapter, but its own setup (e.g. the
            # numpy import) can still fail; keep the node alive and say why.
            log.warning("ML augmentation failed for %s", specialty, exc_info=True)

        return {
            "tool_results": tool_results,
            "messages": [
                AIMessage(
                    content=f"[{specialty}] Information retrieval complete. "
                    f"{len(tool_results)} tools executed."
                )
            ],
            "traces": [
                {
                    "step": "information_retrieval",
                    "specialty": specialty,
                    "tools_called": list(tool_results.keys()),
                    "snapshot_present": has_telemetry,
                }
            ],
        }

    return information_retrieval
|
|
|
|
def _content_to_text(content: Any) -> str:
    """
    Flatten a message ``content`` value to plain text.

    A string is returned unchanged. A list contributes its string items and
    the ``"text"`` entry of dict items (other dict items are ignored); the
    pieces are joined with newlines. ``None`` becomes ``""``. Anything else
    is ``str()``-ed.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "\n".join(pieces)
    return str(content)


def _latest_user_message(msgs: List[Any]) -> str:
    """Return the stripped text of the most recent user/human message in *msgs*, or ""."""
    for msg in reversed(msgs):
        if isinstance(msg, dict):
            role = msg.get("role") or msg.get("type") or ""
            content = msg.get("content", "")
        else:
            # Message objects: prefer `.type`, fall back to the lower-cased class name.
            role = getattr(msg, "type", "") or msg.__class__.__name__.lower()
            content = getattr(msg, "content", "")
        if role in ("user", "human", "humanmessage") and content:
            return _content_to_text(content).strip()
    return ""


def _as_score(value: Any, default: float = 0.0) -> Any:
    """
    Coerce an LLM-supplied score to a number.

    int/float values are returned unchanged and numeric strings go through
    ``float()``. Anything else (None, bools, lists, junk text) yields
    *default*, so comparisons such as ``score >= 5`` cannot raise.
    """
    if isinstance(value, bool):
        return default
    if isinstance(value, (int, float)):
        return value
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _parse_llm_json(content: str, domain: str) -> Dict[str, Any]:
    """
    Extract the JSON object from an LLM reply.

    Parses the span from the first "{" to the last "}", or the whole reply
    when there is no brace. If that does not yield a JSON *object*, returns
    a fallback dict with these fields:

    - "finding": the stripped reply, cut to 1000 characters plus "..." when
      longer.
    - "severity": "unknown".
    - "severity_score": 0.0.
    - "confidence": 0.5.
    """
    import re

    text = content.strip()
    match = re.search(r"\{.*\}", text, re.DOTALL)
    try:
        parsed = json.loads(match.group(0) if match else text)
    except (ValueError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    return {
        "expert": domain,
        "finding": text[:1000] + "..." if len(text) > 1000 else text,
        "severity": "unknown",
        "severity_score": 0.0,
        "confidence": 0.5,
    }


def _make_interpretation_generation(specialty: str):
    """
    Create the interpretation_generation node for a given specialty.

    The node proceeds in these steps:

    1. Read the latest user question from ``state["messages"]``.
    2. Fetch up to 5 RAG history items via ``get_history`` (best-effort).
    3. Build the expert prompt from tool results, history, the patient
       profile and the live telemetry, appending the user question when
       present.
    4. Call the LLM and parse its JSON reply.
    5. Persist the result to the vector store and the DB (both
       best-effort).

    Args:
        specialty: Specialty used for the prompt, model selection, history
            lookup and persistence. ``state["expert_domain"]`` overrides it
            only for labels.

    Returns:
        The node callable ``(state) -> dict``. On success the dict holds
        ``final_expert_analysis``, ``messages`` and ``traces``. If the LLM
        call fails, it instead holds ``error_message``, a minimal
        ``final_expert_analysis`` and a failure message.
    """
    import logging

    log = logging.getLogger(__name__)

    def interpretation_generation(state: ExpertSubgraphState) -> Dict[str, Any]:
        domain = state.get("expert_domain") or specialty
        tool_results = state.get("tool_results") or {}
        shared = state.get("shared_context") or {}
        telemetry = state.get("sensor_telemetry") or {}

        user_message = _latest_user_message(state.get("messages") or [])

        tool_results_str = "\n".join(
            f"### {name}\n```json\n{result}\n```" for name, result in tool_results.items()
        )

        patient_id = shared.get("patient_id", "default_patient")

        # Retrieve against the user's actual question when there is one.
        # History is only context: a vector-store outage must not block the assessment.
        rag_query = user_message or f"{domain} assessment current session"
        try:
            history = get_history(
                specialty=specialty,
                patient_id=patient_id,
                query=rag_query,
                k=5,
            )
        except Exception:
            log.warning("History lookup failed for %s", specialty, exc_info=True)
            history = None

        patient_profile = (
            shared.get("patient_profile")
            or shared.get("patient_base")
            or ""
        )
        if isinstance(patient_profile, dict):
            patient_profile = json.dumps(patient_profile, indent=2)

        telemetry_str = json.dumps(telemetry, indent=2) if telemetry else ""

        system_prompt = get_expert_prompt(
            specialty=specialty,
            tool_results=tool_results_str,
            history=history if history else None,
            patient_profile=str(patient_profile),
            telemetry_context=telemetry_str,
        )

        if user_message:
            system_prompt += (
                "\n"
                "\n"
                "## USER QUESTION\n"
                "The patient (or clinician) has asked the following. Answer it directly,\n"
                "grounding every claim in the live telemetry, tool results, patient\n"
                "profile, and clinical reference knowledge above. The 'finding' field\n"
                "of your JSON output must be a substantive answer to THIS specific\n"
                "question — not a generic assessment.\n"
                "\n"
                f"QUESTION: {user_message}\n"
            )

        try:
            resp = _get_model(specialty).invoke(
                [HumanMessage(content=system_prompt)]
            )
            # Content may be a str or a list of content blocks; normalise to text.
            content = _content_to_text(getattr(resp, "content", None)) or str(resp)
        except Exception as e:
            log.warning("LLM interpretation failed for %s", domain, exc_info=True)
            return {
                "error_message": f"{type(e).__name__}: {e}",
                "final_expert_analysis": {
                    "expert_domain": domain,
                    "clinical_findings": f"Interpretation failed: {e}",
                    "anomaly_detected": False,
                    "confidence_score": 0.0,
                },
                "messages": [
                    AIMessage(content=f"[{domain}] Interpretation generation failed.")
                ],
            }

        parsed = _parse_llm_json(content, domain)

        # A missing/null finding falls back to the raw reply. Structured findings are
        # serialised so the slicing and .strip() below always operate on text.
        finding = parsed.get("finding")
        if finding is None:
            clinical_findings = content
        elif isinstance(finding, str):
            clinical_findings = finding
        else:
            clinical_findings = json.dumps(finding, ensure_ascii=False)

        severity_score = _as_score(parsed.get("severity_score"), 0.0)

        analysis = {
            "expert_domain": domain,
            "clinical_findings": clinical_findings,
            "severity": parsed.get("severity", "unknown"),
            "severity_score": severity_score,
            "key_observations": parsed.get("key_observations", []),
            "recommendations": parsed.get("recommendations", []),
            "confidence_score": parsed.get("confidence", 0.6),
            "anomaly_detected": severity_score >= 5,
            "generated_at": get_today_str(),
        }

        summary = (
            f"[{domain}] Severity: {analysis['severity']} "
            f"(score: {analysis['severity_score']}). "
            f"Findings: {analysis['clinical_findings'][:300]}"
        )
        try:
            save_interpretation(
                specialty=specialty,
                patient_id=patient_id,
                interpretation=content,
                summary=summary,
                metadata={
                    "severity": analysis["severity"],
                    "severity_score": analysis["severity_score"],
                    "generated_at": analysis["generated_at"],
                },
            )
        except Exception:
            log.warning("Vector-store save failed for %s", specialty, exc_info=True)

        try:
            insert_interpretation(
                specialty=specialty,
                findings=analysis["clinical_findings"],
                severity=analysis["severity"],
                severity_score=analysis["severity_score"],
            )
        except Exception:
            log.warning("DB insert failed for %s", specialty, exc_info=True)

        return {
            "final_expert_analysis": analysis,
            "messages": [
                AIMessage(content=analysis["clinical_findings"].strip())
            ],
            "traces": [
                {
                    "step": "interpretation_generation",
                    "specialty": specialty,
                    "severity": analysis["severity"],
                    "severity_score": analysis["severity_score"],
                }
            ],
        }

    return interpretation_generation
|
|
|
|
| |
|
|
|
|
def build_expert_graph(specialty: str) -> StateGraph:
    """
    Assemble the two-step expert pipeline for one medical specialty.

    Topology (strictly linear):
        __start__ → information_retrieval → interpretation_generation → __end__

    Args:
        specialty: Key from EXPERT_TOOLS / EXPERT_SYSTEM_PROMPTS; passed to
            both node factories.

    Returns:
        The StateGraph builder, not yet compiled — call ``.compile()`` on it.
    """
    graph = StateGraph(ExpertSubgraphState)

    # Insertion order of this dict is also the execution order of the pipeline.
    stages = {
        "information_retrieval": _make_information_retrieval(specialty),
        "interpretation_generation": _make_interpretation_generation(specialty),
    }
    for stage_name, stage_fn in stages.items():
        graph.add_node(stage_name, stage_fn)

    # Chain consecutive hops: __start__ → first stage → … → END.
    route = ["__start__", *stages, END]
    for upstream, downstream in zip(route, route[1:]):
        graph.add_edge(upstream, downstream)

    return graph
|
|