Spaces:
Sleeping
Sleeping
chore: codebase audit and fixes (ruff, mypy, pytest)
Browse files- Fixed pytest collection hang by adding root conftest.py and testpaths in pytest.ini
- Fixed mypy missing keys for GuildState TypedDict definitions
- Fixed ruff whitespace formatting and unhandled B904 exceptions
- Added skip markers for unmocked LLM API calls in integration tests
- Removed redundant debug script logging and trace files
This view is limited to 50 files because it contains too many changes. Β
See raw diff
- airflow/dags/ingest_pdfs.py +5 -1
- alembic/env.py +1 -3
- alembic/versions/001_initial.py +52 -51
- api/app/__init__.py +1 -0
- api/app/main.py +13 -24
- api/app/routes/analyze.py +34 -45
- api/app/routes/biomarkers.py +8 -22
- api/app/routes/health.py +3 -3
- api/app/services/extraction.py +12 -19
- api/app/services/ragbot.py +19 -25
- archive/evolution/__init__.py +11 -11
- archive/evolution/director.py +79 -101
- archive/evolution/pareto.py +37 -32
- archive/tests/test_evolution_loop.py +32 -52
- archive/tests/test_evolution_quick.py +2 -7
- conftest.py +1 -0
- huggingface/app.py +57 -60
- pytest.ini +2 -0
- scripts/chat.py +46 -51
- scripts/monitor_test.py +2 -1
- scripts/setup_embeddings.py +13 -13
- scripts/test_chat_demo.py +6 -6
- scripts/test_extraction.py +5 -5
- src/agents/biomarker_analyzer.py +18 -24
- src/agents/biomarker_linker.py +33 -68
- src/agents/clinical_guidelines.py +80 -88
- src/agents/confidence_assessor.py +36 -79
- src/agents/disease_explainer.py +60 -62
- src/agents/response_synthesizer.py +84 -79
- src/biomarker_normalization.py +0 -14
- src/biomarker_validator.py +76 -67
- src/config.py +14 -31
- src/database.py +1 -0
- src/evaluation/__init__.py +8 -8
- src/evaluation/evaluators.py +106 -110
- src/exceptions.py +12 -0
- src/gradio_app.py +7 -11
- src/llm_config.py +21 -53
- src/main.py +11 -0
- src/middlewares.py +19 -5
- src/pdf_processor.py +35 -61
- src/repositories/analysis.py +2 -11
- src/repositories/document.py +2 -11
- src/routers/analyze.py +8 -22
- src/routers/ask.py +7 -12
- src/routers/health.py +12 -2
- src/schemas/schemas.py +15 -7
- src/services/agents/context.py +6 -6
- src/services/agents/nodes/retrieve_node.py +2 -8
- src/services/agents/state.py +4 -4
airflow/dags/ingest_pdfs.py
CHANGED
|
@@ -38,7 +38,11 @@ def _ingest_pdfs(**kwargs):
|
|
| 38 |
parser = make_pdf_parser_service()
|
| 39 |
embedding_svc = make_embedding_service()
|
| 40 |
os_client = make_opensearch_client()
|
| 41 |
-
chunker = MedicalTextChunker(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
indexing_svc = IndexingService(chunker, embedding_svc, os_client)
|
| 43 |
|
| 44 |
docs = parser.parse_directory(pdf_dir)
|
|
|
|
| 38 |
parser = make_pdf_parser_service()
|
| 39 |
embedding_svc = make_embedding_service()
|
| 40 |
os_client = make_opensearch_client()
|
| 41 |
+
chunker = MedicalTextChunker(
|
| 42 |
+
target_words=settings.chunking.chunk_size,
|
| 43 |
+
overlap_words=settings.chunking.chunk_overlap,
|
| 44 |
+
min_words=settings.chunking.min_chunk_size,
|
| 45 |
+
)
|
| 46 |
indexing_svc = IndexingService(chunker, embedding_svc, os_client)
|
| 47 |
|
| 48 |
docs = parser.parse_directory(pdf_dir)
|
alembic/env.py
CHANGED
|
@@ -79,9 +79,7 @@ def run_migrations_online() -> None:
|
|
| 79 |
)
|
| 80 |
|
| 81 |
with connectable.connect() as connection:
|
| 82 |
-
context.configure(
|
| 83 |
-
connection=connection, target_metadata=target_metadata
|
| 84 |
-
)
|
| 85 |
|
| 86 |
with context.begin_transaction():
|
| 87 |
context.run_migrations()
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
with connectable.connect() as connection:
|
| 82 |
+
context.configure(connection=connection, target_metadata=target_metadata)
|
|
|
|
|
|
|
| 83 |
|
| 84 |
with context.begin_transaction():
|
| 85 |
context.run_migrations()
|
alembic/versions/001_initial.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
| 1 |
"""initial_tables
|
| 2 |
|
| 3 |
Revision ID: 001
|
| 4 |
-
Revises:
|
| 5 |
Create Date: 2026-02-24 20:58:00.000000
|
| 6 |
|
| 7 |
"""
|
|
|
|
| 8 |
import sqlalchemy as sa
|
| 9 |
|
| 10 |
from alembic import op
|
| 11 |
|
| 12 |
# revision identifiers, used by Alembic.
|
| 13 |
-
revision =
|
| 14 |
down_revision = None
|
| 15 |
branch_labels = None
|
| 16 |
depends_on = None
|
|
@@ -18,64 +19,64 @@ depends_on = None
|
|
| 18 |
|
| 19 |
def upgrade() -> None:
|
| 20 |
op.create_table(
|
| 21 |
-
|
| 22 |
-
sa.Column(
|
| 23 |
-
sa.Column(
|
| 24 |
-
sa.Column(
|
| 25 |
-
sa.Column(
|
| 26 |
-
sa.Column(
|
| 27 |
-
sa.Column(
|
| 28 |
-
sa.Column(
|
| 29 |
-
sa.Column(
|
| 30 |
-
sa.Column(
|
| 31 |
-
sa.Column(
|
| 32 |
-
sa.Column(
|
| 33 |
-
sa.Column(
|
| 34 |
-
sa.Column(
|
| 35 |
-
sa.PrimaryKeyConstraint(
|
| 36 |
)
|
| 37 |
-
op.create_index(op.f(
|
| 38 |
|
| 39 |
op.create_table(
|
| 40 |
-
|
| 41 |
-
sa.Column(
|
| 42 |
-
sa.Column(
|
| 43 |
-
sa.Column(
|
| 44 |
-
sa.Column(
|
| 45 |
-
sa.Column(
|
| 46 |
-
sa.Column(
|
| 47 |
-
sa.Column(
|
| 48 |
-
sa.Column(
|
| 49 |
-
sa.Column(
|
| 50 |
-
sa.Column(
|
| 51 |
-
sa.Column(
|
| 52 |
-
sa.Column(
|
| 53 |
-
sa.Column(
|
| 54 |
-
sa.PrimaryKeyConstraint(
|
| 55 |
-
sa.UniqueConstraint(
|
| 56 |
)
|
| 57 |
-
op.create_index(op.f(
|
| 58 |
|
| 59 |
op.create_table(
|
| 60 |
-
|
| 61 |
-
sa.Column(
|
| 62 |
-
sa.Column(
|
| 63 |
-
sa.Column(
|
| 64 |
-
sa.Column(
|
| 65 |
-
sa.Column(
|
| 66 |
-
sa.Column(
|
| 67 |
-
sa.Column(
|
| 68 |
-
sa.PrimaryKeyConstraint(
|
| 69 |
)
|
| 70 |
-
op.create_index(op.f(
|
| 71 |
|
| 72 |
|
| 73 |
def downgrade() -> None:
|
| 74 |
-
op.drop_index(op.f(
|
| 75 |
-
op.drop_table(
|
| 76 |
|
| 77 |
-
op.drop_index(op.f(
|
| 78 |
-
op.drop_table(
|
| 79 |
|
| 80 |
-
op.drop_index(op.f(
|
| 81 |
-
op.drop_table(
|
|
|
|
| 1 |
"""initial_tables
|
| 2 |
|
| 3 |
Revision ID: 001
|
| 4 |
+
Revises:
|
| 5 |
Create Date: 2026-02-24 20:58:00.000000
|
| 6 |
|
| 7 |
"""
|
| 8 |
+
|
| 9 |
import sqlalchemy as sa
|
| 10 |
|
| 11 |
from alembic import op
|
| 12 |
|
| 13 |
# revision identifiers, used by Alembic.
|
| 14 |
+
revision = "001"
|
| 15 |
down_revision = None
|
| 16 |
branch_labels = None
|
| 17 |
depends_on = None
|
|
|
|
| 19 |
|
| 20 |
def upgrade() -> None:
|
| 21 |
op.create_table(
|
| 22 |
+
"patient_analyses",
|
| 23 |
+
sa.Column("id", sa.String(length=36), nullable=False),
|
| 24 |
+
sa.Column("request_id", sa.String(length=64), nullable=False),
|
| 25 |
+
sa.Column("biomarkers", sa.JSON(), nullable=False),
|
| 26 |
+
sa.Column("patient_context", sa.JSON(), nullable=True),
|
| 27 |
+
sa.Column("predicted_disease", sa.String(length=128), nullable=False),
|
| 28 |
+
sa.Column("confidence", sa.Float(), nullable=False),
|
| 29 |
+
sa.Column("probabilities", sa.JSON(), nullable=True),
|
| 30 |
+
sa.Column("analysis_result", sa.JSON(), nullable=True),
|
| 31 |
+
sa.Column("safety_alerts", sa.JSON(), nullable=True),
|
| 32 |
+
sa.Column("sop_version", sa.String(length=64), nullable=True),
|
| 33 |
+
sa.Column("processing_time_ms", sa.Float(), nullable=False),
|
| 34 |
+
sa.Column("model_provider", sa.String(length=32), nullable=True),
|
| 35 |
+
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
| 36 |
+
sa.PrimaryKeyConstraint("id"),
|
| 37 |
)
|
| 38 |
+
op.create_index(op.f("ix_patient_analyses_request_id"), "patient_analyses", ["request_id"], unique=True)
|
| 39 |
|
| 40 |
op.create_table(
|
| 41 |
+
"medical_documents",
|
| 42 |
+
sa.Column("id", sa.String(length=36), nullable=False),
|
| 43 |
+
sa.Column("title", sa.String(length=512), nullable=False),
|
| 44 |
+
sa.Column("source", sa.String(length=512), nullable=False),
|
| 45 |
+
sa.Column("source_type", sa.String(length=32), nullable=False),
|
| 46 |
+
sa.Column("authors", sa.Text(), nullable=True),
|
| 47 |
+
sa.Column("abstract", sa.Text(), nullable=True),
|
| 48 |
+
sa.Column("content_hash", sa.String(length=64), nullable=True),
|
| 49 |
+
sa.Column("page_count", sa.Integer(), nullable=True),
|
| 50 |
+
sa.Column("chunk_count", sa.Integer(), nullable=True),
|
| 51 |
+
sa.Column("parse_status", sa.String(length=32), nullable=False),
|
| 52 |
+
sa.Column("metadata_json", sa.JSON(), nullable=True),
|
| 53 |
+
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
| 54 |
+
sa.Column("indexed_at", sa.DateTime(timezone=True), nullable=True),
|
| 55 |
+
sa.PrimaryKeyConstraint("id"),
|
| 56 |
+
sa.UniqueConstraint("content_hash"),
|
| 57 |
)
|
| 58 |
+
op.create_index(op.f("ix_medical_documents_title"), "medical_documents", ["title"], unique=False)
|
| 59 |
|
| 60 |
op.create_table(
|
| 61 |
+
"sop_versions",
|
| 62 |
+
sa.Column("id", sa.String(length=36), nullable=False),
|
| 63 |
+
sa.Column("version_tag", sa.String(length=64), nullable=False),
|
| 64 |
+
sa.Column("parameters", sa.JSON(), nullable=False),
|
| 65 |
+
sa.Column("evaluation_scores", sa.JSON(), nullable=True),
|
| 66 |
+
sa.Column("parent_version", sa.String(length=64), nullable=True),
|
| 67 |
+
sa.Column("is_active", sa.Boolean(), nullable=False),
|
| 68 |
+
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
| 69 |
+
sa.PrimaryKeyConstraint("id"),
|
| 70 |
)
|
| 71 |
+
op.create_index(op.f("ix_sop_versions_version_tag"), "sop_versions", ["version_tag"], unique=True)
|
| 72 |
|
| 73 |
|
| 74 |
def downgrade() -> None:
|
| 75 |
+
op.drop_index(op.f("ix_sop_versions_version_tag"), table_name="sop_versions")
|
| 76 |
+
op.drop_table("sop_versions")
|
| 77 |
|
| 78 |
+
op.drop_index(op.f("ix_medical_documents_title"), table_name="medical_documents")
|
| 79 |
+
op.drop_table("medical_documents")
|
| 80 |
|
| 81 |
+
op.drop_index(op.f("ix_patient_analyses_request_id"), table_name="patient_analyses")
|
| 82 |
+
op.drop_table("patient_analyses")
|
api/app/__init__.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
RagBot FastAPI Application
|
| 3 |
"""
|
|
|
|
| 4 |
__version__ = "1.0.0"
|
|
|
|
| 1 |
"""
|
| 2 |
RagBot FastAPI Application
|
| 3 |
"""
|
| 4 |
+
|
| 5 |
__version__ = "1.0.0"
|
api/app/main.py
CHANGED
|
@@ -17,10 +17,7 @@ from app.routes import analyze, biomarkers, health
|
|
| 17 |
from app.services.ragbot import get_ragbot_service
|
| 18 |
|
| 19 |
# Configure logging
|
| 20 |
-
logging.basicConfig(
|
| 21 |
-
level=logging.INFO,
|
| 22 |
-
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 23 |
-
)
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
|
| 26 |
|
|
@@ -28,6 +25,7 @@ logger = logging.getLogger(__name__)
|
|
| 28 |
# LIFESPAN EVENTS
|
| 29 |
# ============================================================================
|
| 30 |
|
|
|
|
| 31 |
@asynccontextmanager
|
| 32 |
async def lifespan(app: FastAPI):
|
| 33 |
"""
|
|
@@ -67,7 +65,7 @@ app = FastAPI(
|
|
| 67 |
lifespan=lifespan,
|
| 68 |
docs_url="/docs",
|
| 69 |
redoc_url="/redoc",
|
| 70 |
-
openapi_url="/openapi.json"
|
| 71 |
)
|
| 72 |
|
| 73 |
|
|
@@ -90,6 +88,7 @@ app.add_middleware(
|
|
| 90 |
# ERROR HANDLERS
|
| 91 |
# ============================================================================
|
| 92 |
|
|
|
|
| 93 |
@app.exception_handler(RequestValidationError)
|
| 94 |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
| 95 |
"""Handle request validation errors"""
|
|
@@ -100,8 +99,8 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
|
| 100 |
"error_code": "VALIDATION_ERROR",
|
| 101 |
"message": "Request validation failed",
|
| 102 |
"details": exc.errors(),
|
| 103 |
-
"body": exc.body
|
| 104 |
-
}
|
| 105 |
)
|
| 106 |
|
| 107 |
|
|
@@ -114,8 +113,8 @@ async def general_exception_handler(request: Request, exc: Exception):
|
|
| 114 |
content={
|
| 115 |
"status": "error",
|
| 116 |
"error_code": "INTERNAL_SERVER_ERROR",
|
| 117 |
-
"message": "An unexpected error occurred. Please try again later."
|
| 118 |
-
}
|
| 119 |
)
|
| 120 |
|
| 121 |
|
|
@@ -144,13 +143,9 @@ async def root():
|
|
| 144 |
"analyze_structured": "/api/v1/analyze/structured",
|
| 145 |
"example": "/api/v1/example",
|
| 146 |
"docs": "/docs",
|
| 147 |
-
"redoc": "/redoc"
|
| 148 |
-
},
|
| 149 |
-
"documentation": {
|
| 150 |
-
"swagger_ui": "/docs",
|
| 151 |
"redoc": "/redoc",
|
| 152 |
-
|
| 153 |
-
}
|
| 154 |
}
|
| 155 |
|
| 156 |
|
|
@@ -164,8 +159,8 @@ async def api_v1_info():
|
|
| 164 |
"GET /api/v1/biomarkers",
|
| 165 |
"POST /api/v1/analyze/natural",
|
| 166 |
"POST /api/v1/analyze/structured",
|
| 167 |
-
"GET /api/v1/example"
|
| 168 |
-
]
|
| 169 |
}
|
| 170 |
|
| 171 |
|
|
@@ -183,10 +178,4 @@ if __name__ == "__main__":
|
|
| 183 |
|
| 184 |
logger.info(f"Starting server on {host}:{port}")
|
| 185 |
|
| 186 |
-
uvicorn.run(
|
| 187 |
-
"app.main:app",
|
| 188 |
-
host=host,
|
| 189 |
-
port=port,
|
| 190 |
-
reload=reload,
|
| 191 |
-
log_level="info"
|
| 192 |
-
)
|
|
|
|
| 17 |
from app.services.ragbot import get_ragbot_service
|
| 18 |
|
| 19 |
# Configure logging
|
| 20 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
|
|
|
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
| 23 |
|
|
|
|
| 25 |
# LIFESPAN EVENTS
|
| 26 |
# ============================================================================
|
| 27 |
|
| 28 |
+
|
| 29 |
@asynccontextmanager
|
| 30 |
async def lifespan(app: FastAPI):
|
| 31 |
"""
|
|
|
|
| 65 |
lifespan=lifespan,
|
| 66 |
docs_url="/docs",
|
| 67 |
redoc_url="/redoc",
|
| 68 |
+
openapi_url="/openapi.json",
|
| 69 |
)
|
| 70 |
|
| 71 |
|
|
|
|
| 88 |
# ERROR HANDLERS
|
| 89 |
# ============================================================================
|
| 90 |
|
| 91 |
+
|
| 92 |
@app.exception_handler(RequestValidationError)
|
| 93 |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
| 94 |
"""Handle request validation errors"""
|
|
|
|
| 99 |
"error_code": "VALIDATION_ERROR",
|
| 100 |
"message": "Request validation failed",
|
| 101 |
"details": exc.errors(),
|
| 102 |
+
"body": exc.body,
|
| 103 |
+
},
|
| 104 |
)
|
| 105 |
|
| 106 |
|
|
|
|
| 113 |
content={
|
| 114 |
"status": "error",
|
| 115 |
"error_code": "INTERNAL_SERVER_ERROR",
|
| 116 |
+
"message": "An unexpected error occurred. Please try again later.",
|
| 117 |
+
},
|
| 118 |
)
|
| 119 |
|
| 120 |
|
|
|
|
| 143 |
"analyze_structured": "/api/v1/analyze/structured",
|
| 144 |
"example": "/api/v1/example",
|
| 145 |
"docs": "/docs",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
"redoc": "/redoc",
|
| 147 |
+
},
|
| 148 |
+
"documentation": {"swagger_ui": "/docs", "redoc": "/redoc", "openapi_schema": "/openapi.json"},
|
| 149 |
}
|
| 150 |
|
| 151 |
|
|
|
|
| 159 |
"GET /api/v1/biomarkers",
|
| 160 |
"POST /api/v1/analyze/natural",
|
| 161 |
"POST /api/v1/analyze/structured",
|
| 162 |
+
"GET /api/v1/example",
|
| 163 |
+
],
|
| 164 |
}
|
| 165 |
|
| 166 |
|
|
|
|
| 178 |
|
| 179 |
logger.info(f"Starting server on {host}:{port}")
|
| 180 |
|
| 181 |
+
uvicorn.run("app.main:app", host=host, port=port, reload=reload, log_level="info")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/app/routes/analyze.py
CHANGED
|
@@ -18,13 +18,13 @@ router = APIRouter(prefix="/api/v1", tags=["analysis"])
|
|
| 18 |
async def analyze_natural(request: NaturalAnalysisRequest):
|
| 19 |
"""
|
| 20 |
Analyze biomarkers from natural language input.
|
| 21 |
-
|
| 22 |
**Flow:**
|
| 23 |
1. Extract biomarkers from natural language using LLM
|
| 24 |
2. Predict disease using rule-based or ML model
|
| 25 |
3. Run complete RAG workflow analysis
|
| 26 |
4. Return comprehensive results
|
| 27 |
-
|
| 28 |
**Example request:**
|
| 29 |
```json
|
| 30 |
{
|
|
@@ -36,7 +36,7 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 36 |
}
|
| 37 |
}
|
| 38 |
```
|
| 39 |
-
|
| 40 |
Returns full detailed analysis with all agent outputs, citations, recommendations.
|
| 41 |
"""
|
| 42 |
|
|
@@ -46,15 +46,12 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 46 |
if not ragbot_service.is_ready():
|
| 47 |
raise HTTPException(
|
| 48 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 49 |
-
detail="RagBot service not initialized. Please try again in a moment."
|
| 50 |
)
|
| 51 |
|
| 52 |
# Extract biomarkers from natural language
|
| 53 |
ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
| 54 |
-
biomarkers, extracted_context, error = extract_biomarkers(
|
| 55 |
-
request.message,
|
| 56 |
-
ollama_base_url=ollama_base_url
|
| 57 |
-
)
|
| 58 |
|
| 59 |
if error:
|
| 60 |
raise HTTPException(
|
|
@@ -63,8 +60,8 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 63 |
"error_code": "EXTRACTION_FAILED",
|
| 64 |
"message": error,
|
| 65 |
"input_received": request.message[:100],
|
| 66 |
-
"suggestion": "Try: 'My glucose is 140 and HbA1c is 7.5'"
|
| 67 |
-
}
|
| 68 |
)
|
| 69 |
|
| 70 |
if not biomarkers:
|
|
@@ -74,8 +71,8 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 74 |
"error_code": "NO_BIOMARKERS_FOUND",
|
| 75 |
"message": "Could not extract any biomarkers from your message",
|
| 76 |
"input_received": request.message[:100],
|
| 77 |
-
"suggestion": "Include specific biomarker values like 'glucose is 140'"
|
| 78 |
-
}
|
| 79 |
)
|
| 80 |
|
| 81 |
# Merge extracted context with request context
|
|
@@ -91,7 +88,7 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 91 |
biomarkers=biomarkers,
|
| 92 |
patient_context=patient_context,
|
| 93 |
model_prediction=model_prediction,
|
| 94 |
-
extracted_biomarkers=biomarkers # Keep original extraction
|
| 95 |
)
|
| 96 |
|
| 97 |
return response
|
|
@@ -102,22 +99,22 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 102 |
detail={
|
| 103 |
"error_code": "ANALYSIS_FAILED",
|
| 104 |
"message": f"Analysis workflow failed: {e!s}",
|
| 105 |
-
"biomarkers_received": biomarkers
|
| 106 |
-
}
|
| 107 |
-
)
|
| 108 |
|
| 109 |
|
| 110 |
@router.post("/analyze/structured", response_model=AnalysisResponse)
|
| 111 |
async def analyze_structured(request: StructuredAnalysisRequest):
|
| 112 |
"""
|
| 113 |
Analyze biomarkers from structured input (skip extraction).
|
| 114 |
-
|
| 115 |
**Flow:**
|
| 116 |
1. Use provided biomarker dictionary directly
|
| 117 |
2. Predict disease using rule-based or ML model
|
| 118 |
3. Run complete RAG workflow analysis
|
| 119 |
4. Return comprehensive results
|
| 120 |
-
|
| 121 |
**Example request:**
|
| 122 |
```json
|
| 123 |
{
|
|
@@ -135,7 +132,7 @@ async def analyze_structured(request: StructuredAnalysisRequest):
|
|
| 135 |
}
|
| 136 |
}
|
| 137 |
```
|
| 138 |
-
|
| 139 |
Use this endpoint when you already have structured biomarker data.
|
| 140 |
Returns full detailed analysis with all agent outputs, citations, recommendations.
|
| 141 |
"""
|
|
@@ -146,7 +143,7 @@ async def analyze_structured(request: StructuredAnalysisRequest):
|
|
| 146 |
if not ragbot_service.is_ready():
|
| 147 |
raise HTTPException(
|
| 148 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 149 |
-
detail="RagBot service not initialized. Please try again in a moment."
|
| 150 |
)
|
| 151 |
|
| 152 |
# Validate biomarkers
|
|
@@ -156,8 +153,8 @@ async def analyze_structured(request: StructuredAnalysisRequest):
|
|
| 156 |
detail={
|
| 157 |
"error_code": "NO_BIOMARKERS",
|
| 158 |
"message": "Biomarkers dictionary cannot be empty",
|
| 159 |
-
"suggestion": "Provide at least one biomarker with a numeric value"
|
| 160 |
-
}
|
| 161 |
)
|
| 162 |
|
| 163 |
# Patient context
|
|
@@ -172,7 +169,7 @@ async def analyze_structured(request: StructuredAnalysisRequest):
|
|
| 172 |
biomarkers=request.biomarkers,
|
| 173 |
patient_context=patient_context,
|
| 174 |
model_prediction=model_prediction,
|
| 175 |
-
extracted_biomarkers=None # No extraction for structured input
|
| 176 |
)
|
| 177 |
|
| 178 |
return response
|
|
@@ -183,26 +180,26 @@ async def analyze_structured(request: StructuredAnalysisRequest):
|
|
| 183 |
detail={
|
| 184 |
"error_code": "ANALYSIS_FAILED",
|
| 185 |
"message": f"Analysis workflow failed: {e!s}",
|
| 186 |
-
"biomarkers_received": request.biomarkers
|
| 187 |
-
}
|
| 188 |
-
)
|
| 189 |
|
| 190 |
|
| 191 |
@router.get("/example", response_model=AnalysisResponse)
|
| 192 |
async def get_example():
|
| 193 |
"""
|
| 194 |
Get example diabetes case analysis.
|
| 195 |
-
|
| 196 |
**Pre-run example case:**
|
| 197 |
- 52-year-old male patient
|
| 198 |
- Elevated glucose and HbA1c
|
| 199 |
- Type 2 Diabetes prediction
|
| 200 |
-
|
| 201 |
Useful for:
|
| 202 |
- Testing API integration
|
| 203 |
- Understanding response format
|
| 204 |
- Demo purposes
|
| 205 |
-
|
| 206 |
Same as CLI chatbot 'example' command.
|
| 207 |
"""
|
| 208 |
|
|
@@ -212,7 +209,7 @@ async def get_example():
|
|
| 212 |
if not ragbot_service.is_ready():
|
| 213 |
raise HTTPException(
|
| 214 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 215 |
-
detail="RagBot service not initialized. Please try again in a moment."
|
| 216 |
)
|
| 217 |
|
| 218 |
# Example biomarkers (Type 2 Diabetes patient)
|
|
@@ -227,15 +224,10 @@ async def get_example():
|
|
| 227 |
"LDL Cholesterol": 165.0,
|
| 228 |
"BMI": 31.2,
|
| 229 |
"Systolic Blood Pressure": 142.0,
|
| 230 |
-
"Diastolic Blood Pressure": 88.0
|
| 231 |
}
|
| 232 |
|
| 233 |
-
patient_context = {
|
| 234 |
-
"age": 52,
|
| 235 |
-
"gender": "male",
|
| 236 |
-
"bmi": 31.2,
|
| 237 |
-
"patient_id": "EXAMPLE-001"
|
| 238 |
-
}
|
| 239 |
|
| 240 |
model_prediction = {
|
| 241 |
"disease": "Diabetes",
|
|
@@ -245,8 +237,8 @@ async def get_example():
|
|
| 245 |
"Heart Disease": 0.08,
|
| 246 |
"Anemia": 0.03,
|
| 247 |
"Thalassemia": 0.01,
|
| 248 |
-
"Thrombocytopenia": 0.01
|
| 249 |
-
}
|
| 250 |
}
|
| 251 |
|
| 252 |
try:
|
|
@@ -255,7 +247,7 @@ async def get_example():
|
|
| 255 |
biomarkers=biomarkers,
|
| 256 |
patient_context=patient_context,
|
| 257 |
model_prediction=model_prediction,
|
| 258 |
-
extracted_biomarkers=None
|
| 259 |
)
|
| 260 |
|
| 261 |
return response
|
|
@@ -263,8 +255,5 @@ async def get_example():
|
|
| 263 |
except Exception as e:
|
| 264 |
raise HTTPException(
|
| 265 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 266 |
-
detail={
|
| 267 |
-
|
| 268 |
-
"message": f"Example analysis failed: {e!s}"
|
| 269 |
-
}
|
| 270 |
-
)
|
|
|
|
| 18 |
async def analyze_natural(request: NaturalAnalysisRequest):
|
| 19 |
"""
|
| 20 |
Analyze biomarkers from natural language input.
|
| 21 |
+
|
| 22 |
**Flow:**
|
| 23 |
1. Extract biomarkers from natural language using LLM
|
| 24 |
2. Predict disease using rule-based or ML model
|
| 25 |
3. Run complete RAG workflow analysis
|
| 26 |
4. Return comprehensive results
|
| 27 |
+
|
| 28 |
**Example request:**
|
| 29 |
```json
|
| 30 |
{
|
|
|
|
| 36 |
}
|
| 37 |
}
|
| 38 |
```
|
| 39 |
+
|
| 40 |
Returns full detailed analysis with all agent outputs, citations, recommendations.
|
| 41 |
"""
|
| 42 |
|
|
|
|
| 46 |
if not ragbot_service.is_ready():
|
| 47 |
raise HTTPException(
|
| 48 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 49 |
+
detail="RagBot service not initialized. Please try again in a moment.",
|
| 50 |
)
|
| 51 |
|
| 52 |
# Extract biomarkers from natural language
|
| 53 |
ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
| 54 |
+
biomarkers, extracted_context, error = extract_biomarkers(request.message, ollama_base_url=ollama_base_url)
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
if error:
|
| 57 |
raise HTTPException(
|
|
|
|
| 60 |
"error_code": "EXTRACTION_FAILED",
|
| 61 |
"message": error,
|
| 62 |
"input_received": request.message[:100],
|
| 63 |
+
"suggestion": "Try: 'My glucose is 140 and HbA1c is 7.5'",
|
| 64 |
+
},
|
| 65 |
)
|
| 66 |
|
| 67 |
if not biomarkers:
|
|
|
|
| 71 |
"error_code": "NO_BIOMARKERS_FOUND",
|
| 72 |
"message": "Could not extract any biomarkers from your message",
|
| 73 |
"input_received": request.message[:100],
|
| 74 |
+
"suggestion": "Include specific biomarker values like 'glucose is 140'",
|
| 75 |
+
},
|
| 76 |
)
|
| 77 |
|
| 78 |
# Merge extracted context with request context
|
|
|
|
| 88 |
biomarkers=biomarkers,
|
| 89 |
patient_context=patient_context,
|
| 90 |
model_prediction=model_prediction,
|
| 91 |
+
extracted_biomarkers=biomarkers, # Keep original extraction
|
| 92 |
)
|
| 93 |
|
| 94 |
return response
|
|
|
|
| 99 |
detail={
|
| 100 |
"error_code": "ANALYSIS_FAILED",
|
| 101 |
"message": f"Analysis workflow failed: {e!s}",
|
| 102 |
+
"biomarkers_received": biomarkers,
|
| 103 |
+
},
|
| 104 |
+
) from e
|
| 105 |
|
| 106 |
|
| 107 |
@router.post("/analyze/structured", response_model=AnalysisResponse)
|
| 108 |
async def analyze_structured(request: StructuredAnalysisRequest):
|
| 109 |
"""
|
| 110 |
Analyze biomarkers from structured input (skip extraction).
|
| 111 |
+
|
| 112 |
**Flow:**
|
| 113 |
1. Use provided biomarker dictionary directly
|
| 114 |
2. Predict disease using rule-based or ML model
|
| 115 |
3. Run complete RAG workflow analysis
|
| 116 |
4. Return comprehensive results
|
| 117 |
+
|
| 118 |
**Example request:**
|
| 119 |
```json
|
| 120 |
{
|
|
|
|
| 132 |
}
|
| 133 |
}
|
| 134 |
```
|
| 135 |
+
|
| 136 |
Use this endpoint when you already have structured biomarker data.
|
| 137 |
Returns full detailed analysis with all agent outputs, citations, recommendations.
|
| 138 |
"""
|
|
|
|
| 143 |
if not ragbot_service.is_ready():
|
| 144 |
raise HTTPException(
|
| 145 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 146 |
+
detail="RagBot service not initialized. Please try again in a moment.",
|
| 147 |
)
|
| 148 |
|
| 149 |
# Validate biomarkers
|
|
|
|
| 153 |
detail={
|
| 154 |
"error_code": "NO_BIOMARKERS",
|
| 155 |
"message": "Biomarkers dictionary cannot be empty",
|
| 156 |
+
"suggestion": "Provide at least one biomarker with a numeric value",
|
| 157 |
+
},
|
| 158 |
)
|
| 159 |
|
| 160 |
# Patient context
|
|
|
|
| 169 |
biomarkers=request.biomarkers,
|
| 170 |
patient_context=patient_context,
|
| 171 |
model_prediction=model_prediction,
|
| 172 |
+
extracted_biomarkers=None, # No extraction for structured input
|
| 173 |
)
|
| 174 |
|
| 175 |
return response
|
|
|
|
| 180 |
detail={
|
| 181 |
"error_code": "ANALYSIS_FAILED",
|
| 182 |
"message": f"Analysis workflow failed: {e!s}",
|
| 183 |
+
"biomarkers_received": request.biomarkers,
|
| 184 |
+
},
|
| 185 |
+
) from e
|
| 186 |
|
| 187 |
|
| 188 |
@router.get("/example", response_model=AnalysisResponse)
|
| 189 |
async def get_example():
|
| 190 |
"""
|
| 191 |
Get example diabetes case analysis.
|
| 192 |
+
|
| 193 |
**Pre-run example case:**
|
| 194 |
- 52-year-old male patient
|
| 195 |
- Elevated glucose and HbA1c
|
| 196 |
- Type 2 Diabetes prediction
|
| 197 |
+
|
| 198 |
Useful for:
|
| 199 |
- Testing API integration
|
| 200 |
- Understanding response format
|
| 201 |
- Demo purposes
|
| 202 |
+
|
| 203 |
Same as CLI chatbot 'example' command.
|
| 204 |
"""
|
| 205 |
|
|
|
|
| 209 |
if not ragbot_service.is_ready():
|
| 210 |
raise HTTPException(
|
| 211 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 212 |
+
detail="RagBot service not initialized. Please try again in a moment.",
|
| 213 |
)
|
| 214 |
|
| 215 |
# Example biomarkers (Type 2 Diabetes patient)
|
|
|
|
| 224 |
"LDL Cholesterol": 165.0,
|
| 225 |
"BMI": 31.2,
|
| 226 |
"Systolic Blood Pressure": 142.0,
|
| 227 |
+
"Diastolic Blood Pressure": 88.0,
|
| 228 |
}
|
| 229 |
|
| 230 |
+
patient_context = {"age": 52, "gender": "male", "bmi": 31.2, "patient_id": "EXAMPLE-001"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
model_prediction = {
|
| 233 |
"disease": "Diabetes",
|
|
|
|
| 237 |
"Heart Disease": 0.08,
|
| 238 |
"Anemia": 0.03,
|
| 239 |
"Thalassemia": 0.01,
|
| 240 |
+
"Thrombocytopenia": 0.01,
|
| 241 |
+
},
|
| 242 |
}
|
| 243 |
|
| 244 |
try:
|
|
|
|
| 247 |
biomarkers=biomarkers,
|
| 248 |
patient_context=patient_context,
|
| 249 |
model_prediction=model_prediction,
|
| 250 |
+
extracted_biomarkers=None,
|
| 251 |
)
|
| 252 |
|
| 253 |
return response
|
|
|
|
| 255 |
except Exception as e:
|
| 256 |
raise HTTPException(
|
| 257 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 258 |
+
detail={"error_code": "EXAMPLE_FAILED", "message": f"Example analysis failed: {e!s}"},
|
| 259 |
+
) from e
|
|
|
|
|
|
|
|
|
api/app/routes/biomarkers.py
CHANGED
|
@@ -17,13 +17,13 @@ router = APIRouter(prefix="/api/v1", tags=["biomarkers"])
|
|
| 17 |
async def list_biomarkers():
|
| 18 |
"""
|
| 19 |
Get list of all supported biomarkers with reference ranges.
|
| 20 |
-
|
| 21 |
Returns comprehensive information about all 24 biomarkers:
|
| 22 |
- Name and unit
|
| 23 |
- Normal reference ranges (gender-specific if applicable)
|
| 24 |
- Critical thresholds
|
| 25 |
- Clinical significance
|
| 26 |
-
|
| 27 |
Useful for:
|
| 28 |
- Frontend validation
|
| 29 |
- Understanding what biomarkers can be analyzed
|
|
@@ -48,18 +48,12 @@ async def list_biomarkers():
|
|
| 48 |
if "male" in normal_range_data or "female" in normal_range_data:
|
| 49 |
# Gender-specific ranges
|
| 50 |
reference_range = BiomarkerReferenceRange(
|
| 51 |
-
min=None,
|
| 52 |
-
max=None,
|
| 53 |
-
male=normal_range_data.get("male"),
|
| 54 |
-
female=normal_range_data.get("female")
|
| 55 |
)
|
| 56 |
else:
|
| 57 |
# Universal range
|
| 58 |
reference_range = BiomarkerReferenceRange(
|
| 59 |
-
min=normal_range_data.get("min"),
|
| 60 |
-
max=normal_range_data.get("max"),
|
| 61 |
-
male=None,
|
| 62 |
-
female=None
|
| 63 |
)
|
| 64 |
|
| 65 |
biomarker_info = BiomarkerInfo(
|
|
@@ -70,25 +64,17 @@ async def list_biomarkers():
|
|
| 70 |
critical_high=info.get("critical_high"),
|
| 71 |
gender_specific=info.get("gender_specific", False),
|
| 72 |
description=info.get("description", ""),
|
| 73 |
-
clinical_significance=info.get("clinical_significance", {})
|
| 74 |
)
|
| 75 |
|
| 76 |
biomarkers_list.append(biomarker_info)
|
| 77 |
|
| 78 |
return BiomarkersListResponse(
|
| 79 |
-
biomarkers=biomarkers_list,
|
| 80 |
-
total_count=len(biomarkers_list),
|
| 81 |
-
timestamp=datetime.now().isoformat()
|
| 82 |
)
|
| 83 |
|
| 84 |
except FileNotFoundError:
|
| 85 |
-
raise HTTPException(
|
| 86 |
-
status_code=500,
|
| 87 |
-
detail="Biomarker configuration file not found"
|
| 88 |
-
)
|
| 89 |
|
| 90 |
except Exception as e:
|
| 91 |
-
raise HTTPException(
|
| 92 |
-
status_code=500,
|
| 93 |
-
detail=f"Failed to load biomarkers: {e!s}"
|
| 94 |
-
)
|
|
|
|
| 17 |
async def list_biomarkers():
|
| 18 |
"""
|
| 19 |
Get list of all supported biomarkers with reference ranges.
|
| 20 |
+
|
| 21 |
Returns comprehensive information about all 24 biomarkers:
|
| 22 |
- Name and unit
|
| 23 |
- Normal reference ranges (gender-specific if applicable)
|
| 24 |
- Critical thresholds
|
| 25 |
- Clinical significance
|
| 26 |
+
|
| 27 |
Useful for:
|
| 28 |
- Frontend validation
|
| 29 |
- Understanding what biomarkers can be analyzed
|
|
|
|
| 48 |
if "male" in normal_range_data or "female" in normal_range_data:
|
| 49 |
# Gender-specific ranges
|
| 50 |
reference_range = BiomarkerReferenceRange(
|
| 51 |
+
min=None, max=None, male=normal_range_data.get("male"), female=normal_range_data.get("female")
|
|
|
|
|
|
|
|
|
|
| 52 |
)
|
| 53 |
else:
|
| 54 |
# Universal range
|
| 55 |
reference_range = BiomarkerReferenceRange(
|
| 56 |
+
min=normal_range_data.get("min"), max=normal_range_data.get("max"), male=None, female=None
|
|
|
|
|
|
|
|
|
|
| 57 |
)
|
| 58 |
|
| 59 |
biomarker_info = BiomarkerInfo(
|
|
|
|
| 64 |
critical_high=info.get("critical_high"),
|
| 65 |
gender_specific=info.get("gender_specific", False),
|
| 66 |
description=info.get("description", ""),
|
| 67 |
+
clinical_significance=info.get("clinical_significance", {}),
|
| 68 |
)
|
| 69 |
|
| 70 |
biomarkers_list.append(biomarker_info)
|
| 71 |
|
| 72 |
return BiomarkersListResponse(
|
| 73 |
+
biomarkers=biomarkers_list, total_count=len(biomarkers_list), timestamp=datetime.now().isoformat()
|
|
|
|
|
|
|
| 74 |
)
|
| 75 |
|
| 76 |
except FileNotFoundError:
|
| 77 |
+
raise HTTPException(status_code=500, detail="Biomarker configuration file not found")
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
except Exception as e:
|
| 80 |
+
raise HTTPException(status_code=500, detail=f"Failed to load biomarkers: {e!s}") from e
|
|
|
|
|
|
|
|
|
api/app/routes/health.py
CHANGED
|
@@ -17,13 +17,13 @@ router = APIRouter(prefix="/api/v1", tags=["health"])
|
|
| 17 |
async def health_check():
|
| 18 |
"""
|
| 19 |
Check API health status.
|
| 20 |
-
|
| 21 |
Verifies:
|
| 22 |
- LLM API connection (Groq/Gemini)
|
| 23 |
- Vector store loaded
|
| 24 |
- Available models
|
| 25 |
- Service uptime
|
| 26 |
-
|
| 27 |
Returns health status with component details.
|
| 28 |
"""
|
| 29 |
ragbot_service = get_ragbot_service()
|
|
@@ -69,5 +69,5 @@ async def health_check():
|
|
| 69 |
vector_store_loaded=vector_store_loaded,
|
| 70 |
available_models=available_models,
|
| 71 |
uptime_seconds=ragbot_service.get_uptime_seconds(),
|
| 72 |
-
version=__version__
|
| 73 |
)
|
|
|
|
| 17 |
async def health_check():
|
| 18 |
"""
|
| 19 |
Check API health status.
|
| 20 |
+
|
| 21 |
Verifies:
|
| 22 |
- LLM API connection (Groq/Gemini)
|
| 23 |
- Vector store loaded
|
| 24 |
- Available models
|
| 25 |
- Service uptime
|
| 26 |
+
|
| 27 |
Returns health status with component details.
|
| 28 |
"""
|
| 29 |
ragbot_service = get_ragbot_service()
|
|
|
|
| 69 |
vector_store_loaded=vector_store_loaded,
|
| 70 |
available_models=available_models,
|
| 71 |
uptime_seconds=ragbot_service.get_uptime_seconds(),
|
| 72 |
+
version=__version__,
|
| 73 |
)
|
api/app/services/extraction.py
CHANGED
|
@@ -54,6 +54,7 @@ If you cannot find any biomarkers, return {{"biomarkers": {{}}, "patient_context
|
|
| 54 |
# EXTRACTION HELPERS
|
| 55 |
# ============================================================================
|
| 56 |
|
|
|
|
| 57 |
def _parse_llm_json(content: str) -> dict[str, Any]:
|
| 58 |
"""Parse JSON payload from LLM output with fallback recovery."""
|
| 59 |
text = content.strip()
|
|
@@ -69,7 +70,7 @@ def _parse_llm_json(content: str) -> dict[str, Any]:
|
|
| 69 |
left = text.find("{")
|
| 70 |
right = text.rfind("}")
|
| 71 |
if left != -1 and right != -1 and right > left:
|
| 72 |
-
return json.loads(text[left:right + 1])
|
| 73 |
raise
|
| 74 |
|
| 75 |
|
|
@@ -77,23 +78,24 @@ def _parse_llm_json(content: str) -> dict[str, Any]:
|
|
| 77 |
# EXTRACTION FUNCTION
|
| 78 |
# ============================================================================
|
| 79 |
|
|
|
|
| 80 |
def extract_biomarkers(
|
| 81 |
user_message: str,
|
| 82 |
-
ollama_base_url: str = None # Kept for backward compatibility, ignored
|
| 83 |
) -> tuple[dict[str, float], dict[str, Any], str]:
|
| 84 |
"""
|
| 85 |
Extract biomarker values from natural language using LLM.
|
| 86 |
-
|
| 87 |
Args:
|
| 88 |
user_message: Natural language text containing biomarker information
|
| 89 |
ollama_base_url: DEPRECATED - uses cloud LLM (Groq/Gemini) instead
|
| 90 |
-
|
| 91 |
Returns:
|
| 92 |
Tuple of (biomarkers_dict, patient_context_dict, error_message)
|
| 93 |
- biomarkers_dict: Normalized biomarker names -> values
|
| 94 |
- patient_context_dict: Extracted patient context (age, gender, BMI)
|
| 95 |
- error_message: Empty string if successful, error description if failed
|
| 96 |
-
|
| 97 |
Example:
|
| 98 |
>>> biomarkers, context, error = extract_biomarkers("My glucose is 185 and HbA1c is 8.2")
|
| 99 |
>>> print(biomarkers)
|
|
@@ -143,24 +145,19 @@ def extract_biomarkers(
|
|
| 143 |
# SIMPLE DISEASE PREDICTION (Fallback)
|
| 144 |
# ============================================================================
|
| 145 |
|
|
|
|
| 146 |
def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 147 |
"""
|
| 148 |
Simple rule-based disease prediction based on key biomarkers.
|
| 149 |
Used as a fallback when no ML model is available.
|
| 150 |
-
|
| 151 |
Args:
|
| 152 |
biomarkers: Dictionary of biomarker names to values
|
| 153 |
-
|
| 154 |
Returns:
|
| 155 |
Dictionary with disease, confidence, and probabilities
|
| 156 |
"""
|
| 157 |
-
scores = {
|
| 158 |
-
"Diabetes": 0.0,
|
| 159 |
-
"Anemia": 0.0,
|
| 160 |
-
"Heart Disease": 0.0,
|
| 161 |
-
"Thrombocytopenia": 0.0,
|
| 162 |
-
"Thalassemia": 0.0
|
| 163 |
-
}
|
| 164 |
|
| 165 |
# Helper: check both abbreviated and normalized biomarker names
|
| 166 |
# Returns None when biomarker is not present (avoids false triggers)
|
|
@@ -230,8 +227,4 @@ def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
|
|
| 230 |
else:
|
| 231 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 232 |
|
| 233 |
-
return {
|
| 234 |
-
"disease": top_disease,
|
| 235 |
-
"confidence": confidence,
|
| 236 |
-
"probabilities": probabilities
|
| 237 |
-
}
|
|
|
|
| 54 |
# EXTRACTION HELPERS
|
| 55 |
# ============================================================================
|
| 56 |
|
| 57 |
+
|
| 58 |
def _parse_llm_json(content: str) -> dict[str, Any]:
|
| 59 |
"""Parse JSON payload from LLM output with fallback recovery."""
|
| 60 |
text = content.strip()
|
|
|
|
| 70 |
left = text.find("{")
|
| 71 |
right = text.rfind("}")
|
| 72 |
if left != -1 and right != -1 and right > left:
|
| 73 |
+
return json.loads(text[left : right + 1])
|
| 74 |
raise
|
| 75 |
|
| 76 |
|
|
|
|
| 78 |
# EXTRACTION FUNCTION
|
| 79 |
# ============================================================================
|
| 80 |
|
| 81 |
+
|
| 82 |
def extract_biomarkers(
|
| 83 |
user_message: str,
|
| 84 |
+
ollama_base_url: str | None = None, # Kept for backward compatibility, ignored
|
| 85 |
) -> tuple[dict[str, float], dict[str, Any], str]:
|
| 86 |
"""
|
| 87 |
Extract biomarker values from natural language using LLM.
|
| 88 |
+
|
| 89 |
Args:
|
| 90 |
user_message: Natural language text containing biomarker information
|
| 91 |
ollama_base_url: DEPRECATED - uses cloud LLM (Groq/Gemini) instead
|
| 92 |
+
|
| 93 |
Returns:
|
| 94 |
Tuple of (biomarkers_dict, patient_context_dict, error_message)
|
| 95 |
- biomarkers_dict: Normalized biomarker names -> values
|
| 96 |
- patient_context_dict: Extracted patient context (age, gender, BMI)
|
| 97 |
- error_message: Empty string if successful, error description if failed
|
| 98 |
+
|
| 99 |
Example:
|
| 100 |
>>> biomarkers, context, error = extract_biomarkers("My glucose is 185 and HbA1c is 8.2")
|
| 101 |
>>> print(biomarkers)
|
|
|
|
| 145 |
# SIMPLE DISEASE PREDICTION (Fallback)
|
| 146 |
# ============================================================================
|
| 147 |
|
| 148 |
+
|
| 149 |
def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 150 |
"""
|
| 151 |
Simple rule-based disease prediction based on key biomarkers.
|
| 152 |
Used as a fallback when no ML model is available.
|
| 153 |
+
|
| 154 |
Args:
|
| 155 |
biomarkers: Dictionary of biomarker names to values
|
| 156 |
+
|
| 157 |
Returns:
|
| 158 |
Dictionary with disease, confidence, and probabilities
|
| 159 |
"""
|
| 160 |
+
scores = {"Diabetes": 0.0, "Anemia": 0.0, "Heart Disease": 0.0, "Thrombocytopenia": 0.0, "Thalassemia": 0.0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# Helper: check both abbreviated and normalized biomarker names
|
| 163 |
# Returns None when biomarker is not present (avoids false triggers)
|
|
|
|
| 227 |
else:
|
| 228 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 229 |
|
| 230 |
+
return {"disease": top_disease, "confidence": confidence, "probabilities": probabilities}
|
|
|
|
|
|
|
|
|
|
|
|
api/app/services/ragbot.py
CHANGED
|
@@ -94,17 +94,17 @@ class RagBotService:
|
|
| 94 |
biomarkers: dict[str, float],
|
| 95 |
patient_context: dict[str, Any],
|
| 96 |
model_prediction: dict[str, Any],
|
| 97 |
-
extracted_biomarkers: dict[str, float] = None
|
| 98 |
) -> AnalysisResponse:
|
| 99 |
"""
|
| 100 |
Run complete analysis workflow and format full detailed response.
|
| 101 |
-
|
| 102 |
Args:
|
| 103 |
biomarkers: Dictionary of biomarker names to values
|
| 104 |
patient_context: Patient demographic information
|
| 105 |
model_prediction: Disease prediction (disease, confidence, probabilities)
|
| 106 |
extracted_biomarkers: Original extracted biomarkers (for natural language input)
|
| 107 |
-
|
| 108 |
Returns:
|
| 109 |
Complete AnalysisResponse with all details
|
| 110 |
"""
|
|
@@ -117,9 +117,7 @@ class RagBotService:
|
|
| 117 |
try:
|
| 118 |
# Create PatientInput
|
| 119 |
patient_input = PatientInput(
|
| 120 |
-
biomarkers=biomarkers,
|
| 121 |
-
model_prediction=model_prediction,
|
| 122 |
-
patient_context=patient_context
|
| 123 |
)
|
| 124 |
|
| 125 |
# Run workflow
|
|
@@ -136,7 +134,7 @@ class RagBotService:
|
|
| 136 |
extracted_biomarkers=extracted_biomarkers,
|
| 137 |
patient_context=patient_context,
|
| 138 |
model_prediction=model_prediction,
|
| 139 |
-
processing_time_ms=processing_time_ms
|
| 140 |
)
|
| 141 |
|
| 142 |
return response
|
|
@@ -153,12 +151,12 @@ class RagBotService:
|
|
| 153 |
extracted_biomarkers: dict[str, float],
|
| 154 |
patient_context: dict[str, Any],
|
| 155 |
model_prediction: dict[str, Any],
|
| 156 |
-
processing_time_ms: float
|
| 157 |
) -> AnalysisResponse:
|
| 158 |
"""
|
| 159 |
Format complete detailed response from workflow result.
|
| 160 |
Preserves ALL data from workflow execution.
|
| 161 |
-
|
| 162 |
workflow_result is now the full LangGraph state dict containing:
|
| 163 |
- final_response: dict from response_synthesizer
|
| 164 |
- agent_outputs: list of AgentOutput objects
|
|
@@ -174,7 +172,7 @@ class RagBotService:
|
|
| 174 |
prediction = Prediction(
|
| 175 |
disease=model_prediction["disease"],
|
| 176 |
confidence=model_prediction["confidence"],
|
| 177 |
-
probabilities=model_prediction.get("probabilities", {})
|
| 178 |
)
|
| 179 |
|
| 180 |
# Biomarker flags: prefer state-level data (BiomarkerFlag objects from validator),
|
|
@@ -183,7 +181,7 @@ class RagBotService:
|
|
| 183 |
if state_flags:
|
| 184 |
biomarker_flags = []
|
| 185 |
for flag in state_flags:
|
| 186 |
-
if hasattr(flag,
|
| 187 |
biomarker_flags.append(BiomarkerFlag(**flag.model_dump()))
|
| 188 |
elif isinstance(flag, dict):
|
| 189 |
biomarker_flags.append(BiomarkerFlag(**flag))
|
|
@@ -201,7 +199,7 @@ class RagBotService:
|
|
| 201 |
if state_alerts:
|
| 202 |
safety_alerts = []
|
| 203 |
for alert in state_alerts:
|
| 204 |
-
if hasattr(alert,
|
| 205 |
safety_alerts.append(SafetyAlert(**alert.model_dump()))
|
| 206 |
elif isinstance(alert, dict):
|
| 207 |
safety_alerts.append(SafetyAlert(**alert))
|
|
@@ -230,7 +228,7 @@ class RagBotService:
|
|
| 230 |
disease_explanation = DiseaseExplanation(
|
| 231 |
pathophysiology=disease_exp_data.get("pathophysiology", ""),
|
| 232 |
citations=disease_exp_data.get("citations", []),
|
| 233 |
-
retrieved_chunks=disease_exp_data.get("retrieved_chunks")
|
| 234 |
)
|
| 235 |
|
| 236 |
# Recommendations from synthesizer
|
|
@@ -243,7 +241,7 @@ class RagBotService:
|
|
| 243 |
immediate_actions=recs_data.get("immediate_actions", []),
|
| 244 |
lifestyle_changes=recs_data.get("lifestyle_changes", []),
|
| 245 |
monitoring=recs_data.get("monitoring", []),
|
| 246 |
-
follow_up=recs_data.get("follow_up")
|
| 247 |
)
|
| 248 |
|
| 249 |
# Confidence assessment from synthesizer
|
|
@@ -254,7 +252,7 @@ class RagBotService:
|
|
| 254 |
prediction_reliability=conf_data.get("prediction_reliability", "UNKNOWN"),
|
| 255 |
evidence_strength=conf_data.get("evidence_strength", "UNKNOWN"),
|
| 256 |
limitations=conf_data.get("limitations", []),
|
| 257 |
-
reasoning=conf_data.get("reasoning")
|
| 258 |
)
|
| 259 |
|
| 260 |
# Alternative diagnoses
|
|
@@ -270,14 +268,14 @@ class RagBotService:
|
|
| 270 |
disease_explanation=disease_explanation,
|
| 271 |
recommendations=recommendations,
|
| 272 |
confidence_assessment=confidence_assessment,
|
| 273 |
-
alternative_diagnoses=alternative_diagnoses
|
| 274 |
)
|
| 275 |
|
| 276 |
# Agent outputs from state (these are src.state.AgentOutput objects)
|
| 277 |
agent_outputs_data = workflow_result.get("agent_outputs", [])
|
| 278 |
agent_outputs = []
|
| 279 |
for agent_out in agent_outputs_data:
|
| 280 |
-
if hasattr(agent_out,
|
| 281 |
agent_outputs.append(AgentOutput(**agent_out.model_dump()))
|
| 282 |
elif isinstance(agent_out, dict):
|
| 283 |
agent_outputs.append(AgentOutput(**agent_out))
|
|
@@ -287,7 +285,7 @@ class RagBotService:
|
|
| 287 |
"sop_version": workflow_result.get("sop_version"),
|
| 288 |
"processing_timestamp": workflow_result.get("processing_timestamp"),
|
| 289 |
"agents_executed": len(agent_outputs),
|
| 290 |
-
"workflow_success": True
|
| 291 |
}
|
| 292 |
|
| 293 |
# Conversational summary (if available)
|
|
@@ -301,7 +299,7 @@ class RagBotService:
|
|
| 301 |
prediction=prediction,
|
| 302 |
safety_alerts=safety_alerts,
|
| 303 |
key_drivers=key_drivers,
|
| 304 |
-
recommendations=recommendations
|
| 305 |
)
|
| 306 |
|
| 307 |
# Assemble final response
|
|
@@ -318,17 +316,13 @@ class RagBotService:
|
|
| 318 |
workflow_metadata=workflow_metadata,
|
| 319 |
conversational_summary=conversational_summary,
|
| 320 |
processing_time_ms=processing_time_ms,
|
| 321 |
-
sop_version=workflow_result.get("sop_version", "Baseline")
|
| 322 |
)
|
| 323 |
|
| 324 |
return response
|
| 325 |
|
| 326 |
def _generate_conversational_summary(
|
| 327 |
-
self,
|
| 328 |
-
prediction: Prediction,
|
| 329 |
-
safety_alerts: list,
|
| 330 |
-
key_drivers: list,
|
| 331 |
-
recommendations: Recommendations
|
| 332 |
) -> str:
|
| 333 |
"""Generate a simple conversational summary"""
|
| 334 |
|
|
|
|
| 94 |
biomarkers: dict[str, float],
|
| 95 |
patient_context: dict[str, Any],
|
| 96 |
model_prediction: dict[str, Any],
|
| 97 |
+
extracted_biomarkers: dict[str, float] | None = None,
|
| 98 |
) -> AnalysisResponse:
|
| 99 |
"""
|
| 100 |
Run complete analysis workflow and format full detailed response.
|
| 101 |
+
|
| 102 |
Args:
|
| 103 |
biomarkers: Dictionary of biomarker names to values
|
| 104 |
patient_context: Patient demographic information
|
| 105 |
model_prediction: Disease prediction (disease, confidence, probabilities)
|
| 106 |
extracted_biomarkers: Original extracted biomarkers (for natural language input)
|
| 107 |
+
|
| 108 |
Returns:
|
| 109 |
Complete AnalysisResponse with all details
|
| 110 |
"""
|
|
|
|
| 117 |
try:
|
| 118 |
# Create PatientInput
|
| 119 |
patient_input = PatientInput(
|
| 120 |
+
biomarkers=biomarkers, model_prediction=model_prediction, patient_context=patient_context
|
|
|
|
|
|
|
| 121 |
)
|
| 122 |
|
| 123 |
# Run workflow
|
|
|
|
| 134 |
extracted_biomarkers=extracted_biomarkers,
|
| 135 |
patient_context=patient_context,
|
| 136 |
model_prediction=model_prediction,
|
| 137 |
+
processing_time_ms=processing_time_ms,
|
| 138 |
)
|
| 139 |
|
| 140 |
return response
|
|
|
|
| 151 |
extracted_biomarkers: dict[str, float],
|
| 152 |
patient_context: dict[str, Any],
|
| 153 |
model_prediction: dict[str, Any],
|
| 154 |
+
processing_time_ms: float,
|
| 155 |
) -> AnalysisResponse:
|
| 156 |
"""
|
| 157 |
Format complete detailed response from workflow result.
|
| 158 |
Preserves ALL data from workflow execution.
|
| 159 |
+
|
| 160 |
workflow_result is now the full LangGraph state dict containing:
|
| 161 |
- final_response: dict from response_synthesizer
|
| 162 |
- agent_outputs: list of AgentOutput objects
|
|
|
|
| 172 |
prediction = Prediction(
|
| 173 |
disease=model_prediction["disease"],
|
| 174 |
confidence=model_prediction["confidence"],
|
| 175 |
+
probabilities=model_prediction.get("probabilities", {}),
|
| 176 |
)
|
| 177 |
|
| 178 |
# Biomarker flags: prefer state-level data (BiomarkerFlag objects from validator),
|
|
|
|
| 181 |
if state_flags:
|
| 182 |
biomarker_flags = []
|
| 183 |
for flag in state_flags:
|
| 184 |
+
if hasattr(flag, "model_dump"):
|
| 185 |
biomarker_flags.append(BiomarkerFlag(**flag.model_dump()))
|
| 186 |
elif isinstance(flag, dict):
|
| 187 |
biomarker_flags.append(BiomarkerFlag(**flag))
|
|
|
|
| 199 |
if state_alerts:
|
| 200 |
safety_alerts = []
|
| 201 |
for alert in state_alerts:
|
| 202 |
+
if hasattr(alert, "model_dump"):
|
| 203 |
safety_alerts.append(SafetyAlert(**alert.model_dump()))
|
| 204 |
elif isinstance(alert, dict):
|
| 205 |
safety_alerts.append(SafetyAlert(**alert))
|
|
|
|
| 228 |
disease_explanation = DiseaseExplanation(
|
| 229 |
pathophysiology=disease_exp_data.get("pathophysiology", ""),
|
| 230 |
citations=disease_exp_data.get("citations", []),
|
| 231 |
+
retrieved_chunks=disease_exp_data.get("retrieved_chunks"),
|
| 232 |
)
|
| 233 |
|
| 234 |
# Recommendations from synthesizer
|
|
|
|
| 241 |
immediate_actions=recs_data.get("immediate_actions", []),
|
| 242 |
lifestyle_changes=recs_data.get("lifestyle_changes", []),
|
| 243 |
monitoring=recs_data.get("monitoring", []),
|
| 244 |
+
follow_up=recs_data.get("follow_up"),
|
| 245 |
)
|
| 246 |
|
| 247 |
# Confidence assessment from synthesizer
|
|
|
|
| 252 |
prediction_reliability=conf_data.get("prediction_reliability", "UNKNOWN"),
|
| 253 |
evidence_strength=conf_data.get("evidence_strength", "UNKNOWN"),
|
| 254 |
limitations=conf_data.get("limitations", []),
|
| 255 |
+
reasoning=conf_data.get("reasoning"),
|
| 256 |
)
|
| 257 |
|
| 258 |
# Alternative diagnoses
|
|
|
|
| 268 |
disease_explanation=disease_explanation,
|
| 269 |
recommendations=recommendations,
|
| 270 |
confidence_assessment=confidence_assessment,
|
| 271 |
+
alternative_diagnoses=alternative_diagnoses,
|
| 272 |
)
|
| 273 |
|
| 274 |
# Agent outputs from state (these are src.state.AgentOutput objects)
|
| 275 |
agent_outputs_data = workflow_result.get("agent_outputs", [])
|
| 276 |
agent_outputs = []
|
| 277 |
for agent_out in agent_outputs_data:
|
| 278 |
+
if hasattr(agent_out, "model_dump"):
|
| 279 |
agent_outputs.append(AgentOutput(**agent_out.model_dump()))
|
| 280 |
elif isinstance(agent_out, dict):
|
| 281 |
agent_outputs.append(AgentOutput(**agent_out))
|
|
|
|
| 285 |
"sop_version": workflow_result.get("sop_version"),
|
| 286 |
"processing_timestamp": workflow_result.get("processing_timestamp"),
|
| 287 |
"agents_executed": len(agent_outputs),
|
| 288 |
+
"workflow_success": True,
|
| 289 |
}
|
| 290 |
|
| 291 |
# Conversational summary (if available)
|
|
|
|
| 299 |
prediction=prediction,
|
| 300 |
safety_alerts=safety_alerts,
|
| 301 |
key_drivers=key_drivers,
|
| 302 |
+
recommendations=recommendations,
|
| 303 |
)
|
| 304 |
|
| 305 |
# Assemble final response
|
|
|
|
| 316 |
workflow_metadata=workflow_metadata,
|
| 317 |
conversational_summary=conversational_summary,
|
| 318 |
processing_time_ms=processing_time_ms,
|
| 319 |
+
sop_version=workflow_result.get("sop_version", "Baseline"),
|
| 320 |
)
|
| 321 |
|
| 322 |
return response
|
| 323 |
|
| 324 |
def _generate_conversational_summary(
|
| 325 |
+
self, prediction: Prediction, safety_alerts: list, key_drivers: list, recommendations: Recommendations
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
) -> str:
|
| 327 |
"""Generate a simple conversational summary"""
|
| 328 |
|
archive/evolution/__init__.py
CHANGED
|
@@ -15,15 +15,15 @@ from .director import (
|
|
| 15 |
from .pareto import analyze_improvements, identify_pareto_front, print_pareto_summary, visualize_pareto_frontier
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
]
|
|
|
|
| 15 |
from .pareto import analyze_improvements, identify_pareto_front, print_pareto_summary, visualize_pareto_frontier
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
+
"Diagnosis",
|
| 19 |
+
"EvolvedSOPs",
|
| 20 |
+
"SOPGenePool",
|
| 21 |
+
"SOPMutation",
|
| 22 |
+
"analyze_improvements",
|
| 23 |
+
"identify_pareto_front",
|
| 24 |
+
"performance_diagnostician",
|
| 25 |
+
"print_pareto_summary",
|
| 26 |
+
"run_evolution_cycle",
|
| 27 |
+
"sop_architect",
|
| 28 |
+
"visualize_pareto_frontier",
|
| 29 |
]
|
archive/evolution/director.py
CHANGED
|
@@ -25,7 +25,7 @@ class SOPGenePool:
|
|
| 25 |
sop: ExplanationSOP,
|
| 26 |
evaluation: EvaluationResult,
|
| 27 |
parent_version: int | None = None,
|
| 28 |
-
description: str = ""
|
| 29 |
):
|
| 30 |
"""Add a new SOP to the gene pool"""
|
| 31 |
self.version_counter += 1
|
|
@@ -34,7 +34,7 @@ class SOPGenePool:
|
|
| 34 |
"sop": sop,
|
| 35 |
"evaluation": evaluation,
|
| 36 |
"parent": parent_version,
|
| 37 |
-
"description": description
|
| 38 |
}
|
| 39 |
self.pool.append(entry)
|
| 40 |
self.gene_pool = self.pool # Keep in sync
|
|
@@ -47,7 +47,7 @@ class SOPGenePool:
|
|
| 47 |
def get_by_version(self, version: int) -> dict[str, Any] | None:
|
| 48 |
"""Retrieve specific SOP version"""
|
| 49 |
for entry in self.pool:
|
| 50 |
-
if entry[
|
| 51 |
return entry
|
| 52 |
return None
|
| 53 |
|
|
@@ -56,10 +56,7 @@ class SOPGenePool:
|
|
| 56 |
if not self.pool:
|
| 57 |
return None
|
| 58 |
|
| 59 |
-
best = max(
|
| 60 |
-
self.pool,
|
| 61 |
-
key=lambda x: getattr(x['evaluation'], metric).score
|
| 62 |
-
)
|
| 63 |
return best
|
| 64 |
|
| 65 |
def summary(self):
|
|
@@ -69,10 +66,10 @@ class SOPGenePool:
|
|
| 69 |
print("=" * 80)
|
| 70 |
|
| 71 |
for entry in self.pool:
|
| 72 |
-
v = entry[
|
| 73 |
-
p = entry[
|
| 74 |
-
desc = entry[
|
| 75 |
-
e = entry[
|
| 76 |
|
| 77 |
parent_str = "(Baseline)" if p is None else f"(Child of v{p})"
|
| 78 |
|
|
@@ -88,23 +85,17 @@ class SOPGenePool:
|
|
| 88 |
|
| 89 |
class Diagnosis(BaseModel):
|
| 90 |
"""Structured diagnosis from Performance Diagnostician"""
|
|
|
|
| 91 |
primary_weakness: Literal[
|
| 92 |
-
|
| 93 |
-
'evidence_grounding',
|
| 94 |
-
'actionability',
|
| 95 |
-
'clarity',
|
| 96 |
-
'safety_completeness'
|
| 97 |
]
|
| 98 |
-
root_cause_analysis: str = Field(
|
| 99 |
-
|
| 100 |
-
)
|
| 101 |
-
recommendation: str = Field(
|
| 102 |
-
description="High-level recommendation to fix the problem"
|
| 103 |
-
)
|
| 104 |
|
| 105 |
|
| 106 |
class SOPMutation(BaseModel):
|
| 107 |
"""Single mutated SOP with description"""
|
|
|
|
| 108 |
description: str = Field(description="Brief description of mutation strategy")
|
| 109 |
# SOP fields from ExplanationSOP
|
| 110 |
biomarker_analyzer_threshold: float = 0.15
|
|
@@ -121,6 +112,7 @@ class SOPMutation(BaseModel):
|
|
| 121 |
|
| 122 |
class EvolvedSOPs(BaseModel):
|
| 123 |
"""Container for mutated SOPs from Architect"""
|
|
|
|
| 124 |
mutations: list[SOPMutation]
|
| 125 |
|
| 126 |
|
|
@@ -135,19 +127,19 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
| 135 |
|
| 136 |
# Find lowest score programmatically (no LLM needed)
|
| 137 |
scores = {
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
}
|
| 144 |
|
| 145 |
reasonings = {
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
}
|
| 152 |
|
| 153 |
primary_weakness = min(scores, key=scores.get)
|
|
@@ -156,25 +148,25 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
| 156 |
|
| 157 |
# Generate detailed root cause analysis
|
| 158 |
root_cause_map = {
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
}
|
| 165 |
|
| 166 |
recommendation_map = {
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
}
|
| 173 |
|
| 174 |
diagnosis = Diagnosis(
|
| 175 |
primary_weakness=primary_weakness,
|
| 176 |
root_cause_analysis=root_cause_map[primary_weakness],
|
| 177 |
-
recommendation=recommendation_map[primary_weakness]
|
| 178 |
)
|
| 179 |
|
| 180 |
print("\nβ Diagnosis complete")
|
|
@@ -184,10 +176,7 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
| 184 |
return diagnosis
|
| 185 |
|
| 186 |
|
| 187 |
-
def sop_architect(
|
| 188 |
-
diagnosis: Diagnosis,
|
| 189 |
-
current_sop: ExplanationSOP
|
| 190 |
-
) -> EvolvedSOPs:
|
| 191 |
"""
|
| 192 |
Generates targeted SOP mutations to address diagnosed weakness.
|
| 193 |
Uses programmatic generation for reliability.
|
|
@@ -200,116 +189,116 @@ def sop_architect(
|
|
| 200 |
weakness = diagnosis.primary_weakness
|
| 201 |
|
| 202 |
# Generate mutations based on weakness type
|
| 203 |
-
if weakness ==
|
| 204 |
mut1 = SOPMutation(
|
| 205 |
disease_explainer_k=max(3, current_sop.disease_explainer_k - 1),
|
| 206 |
linker_retrieval_k=max(2, current_sop.linker_retrieval_k - 1),
|
| 207 |
guideline_retrieval_k=max(2, current_sop.guideline_retrieval_k - 1),
|
| 208 |
-
explainer_detail_level=
|
| 209 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 210 |
use_guideline_agent=current_sop.use_guideline_agent,
|
| 211 |
include_alternative_diagnoses=False,
|
| 212 |
require_pdf_citations=current_sop.require_pdf_citations,
|
| 213 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 214 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 215 |
-
description="Reduce retrieval depth and use concise style for clarity"
|
| 216 |
)
|
| 217 |
mut2 = SOPMutation(
|
| 218 |
disease_explainer_k=current_sop.disease_explainer_k,
|
| 219 |
linker_retrieval_k=current_sop.linker_retrieval_k,
|
| 220 |
guideline_retrieval_k=current_sop.guideline_retrieval_k,
|
| 221 |
-
explainer_detail_level=
|
| 222 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 223 |
use_guideline_agent=current_sop.use_guideline_agent,
|
| 224 |
include_alternative_diagnoses=True,
|
| 225 |
require_pdf_citations=False,
|
| 226 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 227 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 228 |
-
description="Balanced detail with fewer citations for readability"
|
| 229 |
)
|
| 230 |
|
| 231 |
-
elif weakness ==
|
| 232 |
mut1 = SOPMutation(
|
| 233 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
|
| 234 |
linker_retrieval_k=min(5, current_sop.linker_retrieval_k + 1),
|
| 235 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 1),
|
| 236 |
-
explainer_detail_level=
|
| 237 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 238 |
use_guideline_agent=True,
|
| 239 |
include_alternative_diagnoses=current_sop.include_alternative_diagnoses,
|
| 240 |
require_pdf_citations=True,
|
| 241 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 242 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 243 |
-
description="Maximum RAG depth with strict citation requirements"
|
| 244 |
)
|
| 245 |
mut2 = SOPMutation(
|
| 246 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
|
| 247 |
linker_retrieval_k=current_sop.linker_retrieval_k,
|
| 248 |
guideline_retrieval_k=current_sop.guideline_retrieval_k,
|
| 249 |
-
explainer_detail_level=
|
| 250 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 251 |
use_guideline_agent=True,
|
| 252 |
include_alternative_diagnoses=current_sop.include_alternative_diagnoses,
|
| 253 |
require_pdf_citations=True,
|
| 254 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 255 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 256 |
-
description="Moderate RAG increase with citation enforcement"
|
| 257 |
)
|
| 258 |
|
| 259 |
-
elif weakness ==
|
| 260 |
mut1 = SOPMutation(
|
| 261 |
disease_explainer_k=current_sop.disease_explainer_k,
|
| 262 |
linker_retrieval_k=current_sop.linker_retrieval_k,
|
| 263 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 2),
|
| 264 |
-
explainer_detail_level=
|
| 265 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 266 |
use_guideline_agent=True,
|
| 267 |
include_alternative_diagnoses=current_sop.include_alternative_diagnoses,
|
| 268 |
require_pdf_citations=True,
|
| 269 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 270 |
-
critical_value_alert_mode=
|
| 271 |
-
description="Increase guideline retrieval for actionable recommendations"
|
| 272 |
)
|
| 273 |
mut2 = SOPMutation(
|
| 274 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
|
| 275 |
linker_retrieval_k=min(5, current_sop.linker_retrieval_k + 1),
|
| 276 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 1),
|
| 277 |
-
explainer_detail_level=
|
| 278 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 279 |
use_guideline_agent=True,
|
| 280 |
include_alternative_diagnoses=True,
|
| 281 |
require_pdf_citations=True,
|
| 282 |
use_confidence_assessor=True,
|
| 283 |
-
critical_value_alert_mode=
|
| 284 |
-
description="Comprehensive approach with all agents enabled"
|
| 285 |
)
|
| 286 |
|
| 287 |
-
elif weakness ==
|
| 288 |
mut1 = SOPMutation(
|
| 289 |
disease_explainer_k=10,
|
| 290 |
linker_retrieval_k=5,
|
| 291 |
guideline_retrieval_k=5,
|
| 292 |
-
explainer_detail_level=
|
| 293 |
biomarker_analyzer_threshold=max(0.10, current_sop.biomarker_analyzer_threshold - 0.05),
|
| 294 |
use_guideline_agent=True,
|
| 295 |
include_alternative_diagnoses=True,
|
| 296 |
require_pdf_citations=True,
|
| 297 |
use_confidence_assessor=True,
|
| 298 |
-
critical_value_alert_mode=
|
| 299 |
-
description="Maximum RAG depth with strict thresholds for accuracy"
|
| 300 |
)
|
| 301 |
mut2 = SOPMutation(
|
| 302 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
|
| 303 |
linker_retrieval_k=min(5, current_sop.linker_retrieval_k + 1),
|
| 304 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 1),
|
| 305 |
-
explainer_detail_level=
|
| 306 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 307 |
use_guideline_agent=True,
|
| 308 |
include_alternative_diagnoses=True,
|
| 309 |
require_pdf_citations=True,
|
| 310 |
use_confidence_assessor=True,
|
| 311 |
-
critical_value_alert_mode=
|
| 312 |
-
description="High RAG depth with comprehensive detail"
|
| 313 |
)
|
| 314 |
|
| 315 |
else: # safety_completeness
|
|
@@ -317,27 +306,27 @@ def sop_architect(
|
|
| 317 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
|
| 318 |
linker_retrieval_k=current_sop.linker_retrieval_k,
|
| 319 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 2),
|
| 320 |
-
explainer_detail_level=
|
| 321 |
biomarker_analyzer_threshold=max(0.10, current_sop.biomarker_analyzer_threshold - 0.03),
|
| 322 |
use_guideline_agent=True,
|
| 323 |
include_alternative_diagnoses=True,
|
| 324 |
require_pdf_citations=True,
|
| 325 |
use_confidence_assessor=True,
|
| 326 |
-
critical_value_alert_mode=
|
| 327 |
-
description="Strict safety mode with enhanced guidelines"
|
| 328 |
)
|
| 329 |
mut2 = SOPMutation(
|
| 330 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
|
| 331 |
linker_retrieval_k=min(5, current_sop.linker_retrieval_k + 1),
|
| 332 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 1),
|
| 333 |
-
explainer_detail_level=
|
| 334 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 335 |
use_guideline_agent=True,
|
| 336 |
include_alternative_diagnoses=True,
|
| 337 |
require_pdf_citations=True,
|
| 338 |
use_confidence_assessor=True,
|
| 339 |
-
critical_value_alert_mode=
|
| 340 |
-
description="Maximum coverage with all safety features"
|
| 341 |
)
|
| 342 |
|
| 343 |
evolved = EvolvedSOPs(mutations=[mut1, mut2])
|
|
@@ -351,10 +340,7 @@ def sop_architect(
|
|
| 351 |
|
| 352 |
|
| 353 |
def run_evolution_cycle(
|
| 354 |
-
gene_pool: SOPGenePool,
|
| 355 |
-
patient_input: Any,
|
| 356 |
-
workflow_graph: Any,
|
| 357 |
-
evaluation_func: Callable
|
| 358 |
) -> list[dict[str, Any]]:
|
| 359 |
"""
|
| 360 |
Executes one complete evolution cycle:
|
|
@@ -362,7 +348,7 @@ def run_evolution_cycle(
|
|
| 362 |
2. Generate mutations
|
| 363 |
3. Test each mutation
|
| 364 |
4. Add to gene pool
|
| 365 |
-
|
| 366 |
Returns: List of new entries added to pool
|
| 367 |
"""
|
| 368 |
print("\n" + "=" * 80)
|
|
@@ -374,9 +360,9 @@ def run_evolution_cycle(
|
|
| 374 |
if not current_best:
|
| 375 |
raise ValueError("Gene pool is empty. Add baseline SOP first.")
|
| 376 |
|
| 377 |
-
parent_sop = current_best[
|
| 378 |
-
parent_eval = current_best[
|
| 379 |
-
parent_version = current_best[
|
| 380 |
|
| 381 |
print(f"\nImproving upon SOP v{parent_version}")
|
| 382 |
|
|
@@ -395,11 +381,12 @@ def run_evolution_cycle(
|
|
| 395 |
|
| 396 |
# Convert SOPMutation to ExplanationSOP
|
| 397 |
mutant_sop_dict = mutant_sop_model.model_dump()
|
| 398 |
-
description = mutant_sop_dict.pop(
|
| 399 |
mutant_sop = ExplanationSOP(**mutant_sop_dict)
|
| 400 |
|
| 401 |
# Run workflow with mutated SOP
|
| 402 |
from datetime import datetime
|
|
|
|
| 403 |
graph_input = {
|
| 404 |
"patient_biomarkers": patient_input.biomarkers,
|
| 405 |
"model_prediction": patient_input.model_prediction,
|
|
@@ -412,7 +399,7 @@ def run_evolution_cycle(
|
|
| 412 |
"biomarker_analysis": None,
|
| 413 |
"final_response": None,
|
| 414 |
"processing_timestamp": datetime.now().isoformat(),
|
| 415 |
-
"sop_version": description
|
| 416 |
}
|
| 417 |
|
| 418 |
try:
|
|
@@ -420,24 +407,15 @@ def run_evolution_cycle(
|
|
| 420 |
|
| 421 |
# Evaluate output
|
| 422 |
evaluation = evaluation_func(
|
| 423 |
-
final_response=final_state[
|
| 424 |
-
agent_outputs=final_state[
|
| 425 |
-
biomarkers=patient_input.biomarkers
|
| 426 |
)
|
| 427 |
|
| 428 |
# Add to gene pool
|
| 429 |
-
gene_pool.add(
|
| 430 |
-
sop=mutant_sop,
|
| 431 |
-
evaluation=evaluation,
|
| 432 |
-
parent_version=parent_version,
|
| 433 |
-
description=description
|
| 434 |
-
)
|
| 435 |
|
| 436 |
-
new_entries.append({
|
| 437 |
-
"sop": mutant_sop,
|
| 438 |
-
"evaluation": evaluation,
|
| 439 |
-
"description": description
|
| 440 |
-
})
|
| 441 |
except Exception as e:
|
| 442 |
print(f"β Mutation {i} failed: {e}")
|
| 443 |
continue
|
|
|
|
| 25 |
sop: ExplanationSOP,
|
| 26 |
evaluation: EvaluationResult,
|
| 27 |
parent_version: int | None = None,
|
| 28 |
+
description: str = "",
|
| 29 |
):
|
| 30 |
"""Add a new SOP to the gene pool"""
|
| 31 |
self.version_counter += 1
|
|
|
|
| 34 |
"sop": sop,
|
| 35 |
"evaluation": evaluation,
|
| 36 |
"parent": parent_version,
|
| 37 |
+
"description": description,
|
| 38 |
}
|
| 39 |
self.pool.append(entry)
|
| 40 |
self.gene_pool = self.pool # Keep in sync
|
|
|
|
| 47 |
def get_by_version(self, version: int) -> dict[str, Any] | None:
|
| 48 |
"""Retrieve specific SOP version"""
|
| 49 |
for entry in self.pool:
|
| 50 |
+
if entry["version"] == version:
|
| 51 |
return entry
|
| 52 |
return None
|
| 53 |
|
|
|
|
| 56 |
if not self.pool:
|
| 57 |
return None
|
| 58 |
|
| 59 |
+
best = max(self.pool, key=lambda x: getattr(x["evaluation"], metric).score)
|
|
|
|
|
|
|
|
|
|
| 60 |
return best
|
| 61 |
|
| 62 |
def summary(self):
|
|
|
|
| 66 |
print("=" * 80)
|
| 67 |
|
| 68 |
for entry in self.pool:
|
| 69 |
+
v = entry["version"]
|
| 70 |
+
p = entry["parent"]
|
| 71 |
+
desc = entry["description"]
|
| 72 |
+
e = entry["evaluation"]
|
| 73 |
|
| 74 |
parent_str = "(Baseline)" if p is None else f"(Child of v{p})"
|
| 75 |
|
|
|
|
| 85 |
|
| 86 |
class Diagnosis(BaseModel):
|
| 87 |
"""Structured diagnosis from Performance Diagnostician"""
|
| 88 |
+
|
| 89 |
primary_weakness: Literal[
|
| 90 |
+
"clinical_accuracy", "evidence_grounding", "actionability", "clarity", "safety_completeness"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
]
|
| 92 |
+
root_cause_analysis: str = Field(description="Detailed analysis of why weakness occurred")
|
| 93 |
+
recommendation: str = Field(description="High-level recommendation to fix the problem")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
class SOPMutation(BaseModel):
|
| 97 |
"""Single mutated SOP with description"""
|
| 98 |
+
|
| 99 |
description: str = Field(description="Brief description of mutation strategy")
|
| 100 |
# SOP fields from ExplanationSOP
|
| 101 |
biomarker_analyzer_threshold: float = 0.15
|
|
|
|
| 112 |
|
| 113 |
class EvolvedSOPs(BaseModel):
|
| 114 |
"""Container for mutated SOPs from Architect"""
|
| 115 |
+
|
| 116 |
mutations: list[SOPMutation]
|
| 117 |
|
| 118 |
|
|
|
|
| 127 |
|
| 128 |
# Find lowest score programmatically (no LLM needed)
|
| 129 |
scores = {
|
| 130 |
+
"clinical_accuracy": evaluation.clinical_accuracy.score,
|
| 131 |
+
"evidence_grounding": evaluation.evidence_grounding.score,
|
| 132 |
+
"actionability": evaluation.actionability.score,
|
| 133 |
+
"clarity": evaluation.clarity.score,
|
| 134 |
+
"safety_completeness": evaluation.safety_completeness.score,
|
| 135 |
}
|
| 136 |
|
| 137 |
reasonings = {
|
| 138 |
+
"clinical_accuracy": evaluation.clinical_accuracy.reasoning,
|
| 139 |
+
"evidence_grounding": evaluation.evidence_grounding.reasoning,
|
| 140 |
+
"actionability": evaluation.actionability.reasoning,
|
| 141 |
+
"clarity": evaluation.clarity.reasoning,
|
| 142 |
+
"safety_completeness": evaluation.safety_completeness.reasoning,
|
| 143 |
}
|
| 144 |
|
| 145 |
primary_weakness = min(scores, key=scores.get)
|
|
|
|
| 148 |
|
| 149 |
# Generate detailed root cause analysis
|
| 150 |
root_cause_map = {
|
| 151 |
+
"clinical_accuracy": f"Clinical accuracy score ({weakness_score:.2f}) indicates potential issues with medical interpretations. {weakness_reasoning[:200]}",
|
| 152 |
+
"evidence_grounding": f"Evidence grounding score ({weakness_score:.2f}) suggests insufficient citations. {weakness_reasoning[:200]}",
|
| 153 |
+
"actionability": f"Actionability score ({weakness_score:.2f}) indicates recommendations lack specificity. {weakness_reasoning[:200]}",
|
| 154 |
+
"clarity": f"Clarity score ({weakness_score:.2f}) suggests readability issues. {weakness_reasoning[:200]}",
|
| 155 |
+
"safety_completeness": f"Safety score ({weakness_score:.2f}) indicates missing risk discussions. {weakness_reasoning[:200]}",
|
| 156 |
}
|
| 157 |
|
| 158 |
recommendation_map = {
|
| 159 |
+
"clinical_accuracy": "Increase RAG depth to access more authoritative medical sources.",
|
| 160 |
+
"evidence_grounding": "Enforce strict citation requirements and increase RAG depth.",
|
| 161 |
+
"actionability": "Make recommendations more specific with concrete action items.",
|
| 162 |
+
"clarity": "Simplify language and reduce technical jargon for better readability.",
|
| 163 |
+
"safety_completeness": "Add explicit safety warnings and ensure complete risk coverage.",
|
| 164 |
}
|
| 165 |
|
| 166 |
diagnosis = Diagnosis(
|
| 167 |
primary_weakness=primary_weakness,
|
| 168 |
root_cause_analysis=root_cause_map[primary_weakness],
|
| 169 |
+
recommendation=recommendation_map[primary_weakness],
|
| 170 |
)
|
| 171 |
|
| 172 |
print("\nβ Diagnosis complete")
|
|
|
|
| 176 |
return diagnosis
|
| 177 |
|
| 178 |
|
| 179 |
+
def sop_architect(diagnosis: Diagnosis, current_sop: ExplanationSOP) -> EvolvedSOPs:
|
|
|
|
|
|
|
|
|
|
| 180 |
"""
|
| 181 |
Generates targeted SOP mutations to address diagnosed weakness.
|
| 182 |
Uses programmatic generation for reliability.
|
|
|
|
| 189 |
weakness = diagnosis.primary_weakness
|
| 190 |
|
| 191 |
# Generate mutations based on weakness type
|
| 192 |
+
if weakness == "clarity":
|
| 193 |
mut1 = SOPMutation(
|
| 194 |
disease_explainer_k=max(3, current_sop.disease_explainer_k - 1),
|
| 195 |
linker_retrieval_k=max(2, current_sop.linker_retrieval_k - 1),
|
| 196 |
guideline_retrieval_k=max(2, current_sop.guideline_retrieval_k - 1),
|
| 197 |
+
explainer_detail_level="concise",
|
| 198 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 199 |
use_guideline_agent=current_sop.use_guideline_agent,
|
| 200 |
include_alternative_diagnoses=False,
|
| 201 |
require_pdf_citations=current_sop.require_pdf_citations,
|
| 202 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 203 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 204 |
+
description="Reduce retrieval depth and use concise style for clarity",
|
| 205 |
)
|
| 206 |
mut2 = SOPMutation(
|
| 207 |
disease_explainer_k=current_sop.disease_explainer_k,
|
| 208 |
linker_retrieval_k=current_sop.linker_retrieval_k,
|
| 209 |
guideline_retrieval_k=current_sop.guideline_retrieval_k,
|
| 210 |
+
explainer_detail_level="detailed",
|
| 211 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 212 |
use_guideline_agent=current_sop.use_guideline_agent,
|
| 213 |
include_alternative_diagnoses=True,
|
| 214 |
require_pdf_citations=False,
|
| 215 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 216 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 217 |
+
description="Balanced detail with fewer citations for readability",
|
| 218 |
)
|
| 219 |
|
| 220 |
+
elif weakness == "evidence_grounding":
|
| 221 |
mut1 = SOPMutation(
|
| 222 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
|
| 223 |
linker_retrieval_k=min(5, current_sop.linker_retrieval_k + 1),
|
| 224 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 1),
|
| 225 |
+
explainer_detail_level="comprehensive",
|
| 226 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 227 |
use_guideline_agent=True,
|
| 228 |
include_alternative_diagnoses=current_sop.include_alternative_diagnoses,
|
| 229 |
require_pdf_citations=True,
|
| 230 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 231 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 232 |
+
description="Maximum RAG depth with strict citation requirements",
|
| 233 |
)
|
| 234 |
mut2 = SOPMutation(
|
| 235 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
|
| 236 |
linker_retrieval_k=current_sop.linker_retrieval_k,
|
| 237 |
guideline_retrieval_k=current_sop.guideline_retrieval_k,
|
| 238 |
+
explainer_detail_level="detailed",
|
| 239 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 240 |
use_guideline_agent=True,
|
| 241 |
include_alternative_diagnoses=current_sop.include_alternative_diagnoses,
|
| 242 |
require_pdf_citations=True,
|
| 243 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 244 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 245 |
+
description="Moderate RAG increase with citation enforcement",
|
| 246 |
)
|
| 247 |
|
| 248 |
+
elif weakness == "actionability":
|
| 249 |
mut1 = SOPMutation(
|
| 250 |
disease_explainer_k=current_sop.disease_explainer_k,
|
| 251 |
linker_retrieval_k=current_sop.linker_retrieval_k,
|
| 252 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 2),
|
| 253 |
+
explainer_detail_level="comprehensive",
|
| 254 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 255 |
use_guideline_agent=True,
|
| 256 |
include_alternative_diagnoses=current_sop.include_alternative_diagnoses,
|
| 257 |
require_pdf_citations=True,
|
| 258 |
use_confidence_assessor=current_sop.use_confidence_assessor,
|
| 259 |
+
critical_value_alert_mode="strict",
|
| 260 |
+
description="Increase guideline retrieval for actionable recommendations",
|
| 261 |
)
|
| 262 |
mut2 = SOPMutation(
|
| 263 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
|
| 264 |
linker_retrieval_k=min(5, current_sop.linker_retrieval_k + 1),
|
| 265 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 1),
|
| 266 |
+
explainer_detail_level="detailed",
|
| 267 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 268 |
use_guideline_agent=True,
|
| 269 |
include_alternative_diagnoses=True,
|
| 270 |
require_pdf_citations=True,
|
| 271 |
use_confidence_assessor=True,
|
| 272 |
+
critical_value_alert_mode="strict",
|
| 273 |
+
description="Comprehensive approach with all agents enabled",
|
| 274 |
)
|
| 275 |
|
| 276 |
+
elif weakness == "clinical_accuracy":
|
| 277 |
mut1 = SOPMutation(
|
| 278 |
disease_explainer_k=10,
|
| 279 |
linker_retrieval_k=5,
|
| 280 |
guideline_retrieval_k=5,
|
| 281 |
+
explainer_detail_level="comprehensive",
|
| 282 |
biomarker_analyzer_threshold=max(0.10, current_sop.biomarker_analyzer_threshold - 0.05),
|
| 283 |
use_guideline_agent=True,
|
| 284 |
include_alternative_diagnoses=True,
|
| 285 |
require_pdf_citations=True,
|
| 286 |
use_confidence_assessor=True,
|
| 287 |
+
critical_value_alert_mode="strict",
|
| 288 |
+
description="Maximum RAG depth with strict thresholds for accuracy",
|
| 289 |
)
|
| 290 |
mut2 = SOPMutation(
|
| 291 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
|
| 292 |
linker_retrieval_k=min(5, current_sop.linker_retrieval_k + 1),
|
| 293 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 1),
|
| 294 |
+
explainer_detail_level="comprehensive",
|
| 295 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 296 |
use_guideline_agent=True,
|
| 297 |
include_alternative_diagnoses=True,
|
| 298 |
require_pdf_citations=True,
|
| 299 |
use_confidence_assessor=True,
|
| 300 |
+
critical_value_alert_mode="strict",
|
| 301 |
+
description="High RAG depth with comprehensive detail",
|
| 302 |
)
|
| 303 |
|
| 304 |
else: # safety_completeness
|
|
|
|
| 306 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
|
| 307 |
linker_retrieval_k=current_sop.linker_retrieval_k,
|
| 308 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 2),
|
| 309 |
+
explainer_detail_level="comprehensive",
|
| 310 |
biomarker_analyzer_threshold=max(0.10, current_sop.biomarker_analyzer_threshold - 0.03),
|
| 311 |
use_guideline_agent=True,
|
| 312 |
include_alternative_diagnoses=True,
|
| 313 |
require_pdf_citations=True,
|
| 314 |
use_confidence_assessor=True,
|
| 315 |
+
critical_value_alert_mode="strict",
|
| 316 |
+
description="Strict safety mode with enhanced guidelines",
|
| 317 |
)
|
| 318 |
mut2 = SOPMutation(
|
| 319 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
|
| 320 |
linker_retrieval_k=min(5, current_sop.linker_retrieval_k + 1),
|
| 321 |
guideline_retrieval_k=min(5, current_sop.guideline_retrieval_k + 1),
|
| 322 |
+
explainer_detail_level="comprehensive",
|
| 323 |
biomarker_analyzer_threshold=current_sop.biomarker_analyzer_threshold,
|
| 324 |
use_guideline_agent=True,
|
| 325 |
include_alternative_diagnoses=True,
|
| 326 |
require_pdf_citations=True,
|
| 327 |
use_confidence_assessor=True,
|
| 328 |
+
critical_value_alert_mode="strict",
|
| 329 |
+
description="Maximum coverage with all safety features",
|
| 330 |
)
|
| 331 |
|
| 332 |
evolved = EvolvedSOPs(mutations=[mut1, mut2])
|
|
|
|
| 340 |
|
| 341 |
|
| 342 |
def run_evolution_cycle(
|
| 343 |
+
gene_pool: SOPGenePool, patient_input: Any, workflow_graph: Any, evaluation_func: Callable
|
|
|
|
|
|
|
|
|
|
| 344 |
) -> list[dict[str, Any]]:
|
| 345 |
"""
|
| 346 |
Executes one complete evolution cycle:
|
|
|
|
| 348 |
2. Generate mutations
|
| 349 |
3. Test each mutation
|
| 350 |
4. Add to gene pool
|
| 351 |
+
|
| 352 |
Returns: List of new entries added to pool
|
| 353 |
"""
|
| 354 |
print("\n" + "=" * 80)
|
|
|
|
| 360 |
if not current_best:
|
| 361 |
raise ValueError("Gene pool is empty. Add baseline SOP first.")
|
| 362 |
|
| 363 |
+
parent_sop = current_best["sop"]
|
| 364 |
+
parent_eval = current_best["evaluation"]
|
| 365 |
+
parent_version = current_best["version"]
|
| 366 |
|
| 367 |
print(f"\nImproving upon SOP v{parent_version}")
|
| 368 |
|
|
|
|
| 381 |
|
| 382 |
# Convert SOPMutation to ExplanationSOP
|
| 383 |
mutant_sop_dict = mutant_sop_model.model_dump()
|
| 384 |
+
description = mutant_sop_dict.pop("description")
|
| 385 |
mutant_sop = ExplanationSOP(**mutant_sop_dict)
|
| 386 |
|
| 387 |
# Run workflow with mutated SOP
|
| 388 |
from datetime import datetime
|
| 389 |
+
|
| 390 |
graph_input = {
|
| 391 |
"patient_biomarkers": patient_input.biomarkers,
|
| 392 |
"model_prediction": patient_input.model_prediction,
|
|
|
|
| 399 |
"biomarker_analysis": None,
|
| 400 |
"final_response": None,
|
| 401 |
"processing_timestamp": datetime.now().isoformat(),
|
| 402 |
+
"sop_version": description,
|
| 403 |
}
|
| 404 |
|
| 405 |
try:
|
|
|
|
| 407 |
|
| 408 |
# Evaluate output
|
| 409 |
evaluation = evaluation_func(
|
| 410 |
+
final_response=final_state["final_response"],
|
| 411 |
+
agent_outputs=final_state["agent_outputs"],
|
| 412 |
+
biomarkers=patient_input.biomarkers,
|
| 413 |
)
|
| 414 |
|
| 415 |
# Add to gene pool
|
| 416 |
+
gene_pool.add(sop=mutant_sop, evaluation=evaluation, parent_version=parent_version, description=description)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
|
| 418 |
+
new_entries.append({"sop": mutant_sop, "evaluation": evaluation, "description": description})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
except Exception as e:
|
| 420 |
print(f"β Mutation {i} failed: {e}")
|
| 421 |
continue
|
archive/evolution/pareto.py
CHANGED
|
@@ -8,14 +8,14 @@ from typing import Any
|
|
| 8 |
import matplotlib
|
| 9 |
import numpy as np
|
| 10 |
|
| 11 |
-
matplotlib.use(
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
|
| 14 |
|
| 15 |
def identify_pareto_front(gene_pool_entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
| 16 |
"""
|
| 17 |
Identifies non-dominated solutions (Pareto Frontier).
|
| 18 |
-
|
| 19 |
A solution is dominated if another solution is:
|
| 20 |
- Better or equal on ALL metrics
|
| 21 |
- Strictly better on AT LEAST ONE metric
|
|
@@ -26,14 +26,14 @@ def identify_pareto_front(gene_pool_entries: list[dict[str, Any]]) -> list[dict[
|
|
| 26 |
is_dominated = False
|
| 27 |
|
| 28 |
# Get candidate's 5D score vector
|
| 29 |
-
cand_scores = np.array(candidate[
|
| 30 |
|
| 31 |
for j, other in enumerate(gene_pool_entries):
|
| 32 |
if i == j:
|
| 33 |
continue
|
| 34 |
|
| 35 |
# Get other solution's 5D vector
|
| 36 |
-
other_scores = np.array(other[
|
| 37 |
|
| 38 |
# Check domination: other >= candidate on ALL, other > candidate on SOME
|
| 39 |
if np.all(other_scores >= cand_scores) and np.any(other_scores > cand_scores):
|
|
@@ -61,75 +61,75 @@ def visualize_pareto_frontier(pareto_front: list[dict[str, Any]]):
|
|
| 61 |
# --- Plot 1: Bar Chart (since pandas might not be available) ---
|
| 62 |
ax1 = plt.subplot(1, 2, 1)
|
| 63 |
|
| 64 |
-
metrics = [
|
| 65 |
x = np.arange(len(metrics))
|
| 66 |
width = 0.8 / len(pareto_front)
|
| 67 |
|
| 68 |
for idx, entry in enumerate(pareto_front):
|
| 69 |
-
e = entry[
|
| 70 |
scores = [
|
| 71 |
e.clinical_accuracy.score,
|
| 72 |
e.evidence_grounding.score,
|
| 73 |
e.actionability.score,
|
| 74 |
e.clarity.score,
|
| 75 |
-
e.safety_completeness.score
|
| 76 |
]
|
| 77 |
|
| 78 |
offset = (idx - len(pareto_front) / 2) * width + width / 2
|
| 79 |
label = f"SOP v{entry['version']}"
|
| 80 |
ax1.bar(x + offset, scores, width, label=label, alpha=0.8)
|
| 81 |
|
| 82 |
-
ax1.set_xlabel(
|
| 83 |
-
ax1.set_ylabel(
|
| 84 |
-
ax1.set_title(
|
| 85 |
ax1.set_xticks(x)
|
| 86 |
ax1.set_xticklabels(metrics, fontsize=10)
|
| 87 |
ax1.set_ylim(0, 1.0)
|
| 88 |
-
ax1.legend(loc=
|
| 89 |
-
ax1.grid(True, alpha=0.3, axis=
|
| 90 |
|
| 91 |
# --- Plot 2: Radar Chart ---
|
| 92 |
-
ax2 = plt.subplot(1, 2, 2, projection=
|
| 93 |
|
| 94 |
-
categories = [
|
| 95 |
-
'Actionability', 'Clarity', 'Safety']
|
| 96 |
num_vars = len(categories)
|
| 97 |
|
| 98 |
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
| 99 |
angles += angles[:1]
|
| 100 |
|
| 101 |
for entry in pareto_front:
|
| 102 |
-
e = entry[
|
| 103 |
values = [
|
| 104 |
e.clinical_accuracy.score,
|
| 105 |
e.evidence_grounding.score,
|
| 106 |
e.actionability.score,
|
| 107 |
e.clarity.score,
|
| 108 |
-
e.safety_completeness.score
|
| 109 |
]
|
| 110 |
values += values[:1]
|
| 111 |
|
| 112 |
-
desc = entry.get(
|
| 113 |
label = f"SOP v{entry['version']}: {desc}"
|
| 114 |
-
ax2.plot(angles, values,
|
| 115 |
ax2.fill(angles, values, alpha=0.15)
|
| 116 |
|
| 117 |
ax2.set_xticks(angles[:-1])
|
| 118 |
ax2.set_xticklabels(categories, size=10)
|
| 119 |
ax2.set_ylim(0, 1)
|
| 120 |
-
ax2.set_title(
|
| 121 |
-
ax2.legend(loc=
|
| 122 |
ax2.grid(True)
|
| 123 |
|
| 124 |
plt.tight_layout()
|
| 125 |
|
| 126 |
# Create data directory if it doesn't exist
|
| 127 |
from pathlib import Path
|
| 128 |
-
|
|
|
|
| 129 |
data_dir.mkdir(exist_ok=True)
|
| 130 |
|
| 131 |
-
output_path = data_dir /
|
| 132 |
-
plt.savefig(output_path, dpi=300, bbox_inches=
|
| 133 |
plt.close()
|
| 134 |
|
| 135 |
print(f"\nβ Visualization saved to: {output_path}")
|
|
@@ -144,10 +144,10 @@ def print_pareto_summary(pareto_front: list[dict[str, Any]]):
|
|
| 144 |
print(f"\nFound {len(pareto_front)} optimal (non-dominated) solutions:\n")
|
| 145 |
|
| 146 |
for entry in pareto_front:
|
| 147 |
-
v = entry[
|
| 148 |
-
p = entry.get(
|
| 149 |
-
desc = entry.get(
|
| 150 |
-
e = entry[
|
| 151 |
|
| 152 |
print(f"SOP v{v} {f'(Child of v{p})' if p else '(Baseline)'}")
|
| 153 |
print(f" Description: {desc}")
|
|
@@ -176,7 +176,7 @@ def analyze_improvements(gene_pool_entries: list[dict[str, Any]]):
|
|
| 176 |
return
|
| 177 |
|
| 178 |
baseline = gene_pool_entries[0]
|
| 179 |
-
baseline_scores = np.array(baseline[
|
| 180 |
|
| 181 |
print("\n" + "=" * 80)
|
| 182 |
print("IMPROVEMENT ANALYSIS")
|
|
@@ -187,7 +187,7 @@ def analyze_improvements(gene_pool_entries: list[dict[str, Any]]):
|
|
| 187 |
|
| 188 |
improvements_found = False
|
| 189 |
for entry in gene_pool_entries[1:]:
|
| 190 |
-
scores = np.array(entry[
|
| 191 |
avg_score = np.mean(scores)
|
| 192 |
baseline_avg = np.mean(baseline_scores)
|
| 193 |
|
|
@@ -199,8 +199,13 @@ def analyze_improvements(gene_pool_entries: list[dict[str, Any]]):
|
|
| 199 |
print(f" Average Score: {avg_score:.3f} (+{improvement_pct:.1f}% vs baseline)")
|
| 200 |
|
| 201 |
# Show per-metric improvements
|
| 202 |
-
metric_names = [
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
for i, (name, score, baseline_score) in enumerate(zip(metric_names, scores, baseline_scores)):
|
| 205 |
diff = score - baseline_score
|
| 206 |
if abs(diff) > 0.01: # Show significant changes
|
|
|
|
| 8 |
import matplotlib
|
| 9 |
import numpy as np
|
| 10 |
|
| 11 |
+
matplotlib.use("Agg") # Use non-interactive backend
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
|
| 14 |
|
| 15 |
def identify_pareto_front(gene_pool_entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
| 16 |
"""
|
| 17 |
Identifies non-dominated solutions (Pareto Frontier).
|
| 18 |
+
|
| 19 |
A solution is dominated if another solution is:
|
| 20 |
- Better or equal on ALL metrics
|
| 21 |
- Strictly better on AT LEAST ONE metric
|
|
|
|
| 26 |
is_dominated = False
|
| 27 |
|
| 28 |
# Get candidate's 5D score vector
|
| 29 |
+
cand_scores = np.array(candidate["evaluation"].to_vector())
|
| 30 |
|
| 31 |
for j, other in enumerate(gene_pool_entries):
|
| 32 |
if i == j:
|
| 33 |
continue
|
| 34 |
|
| 35 |
# Get other solution's 5D vector
|
| 36 |
+
other_scores = np.array(other["evaluation"].to_vector())
|
| 37 |
|
| 38 |
# Check domination: other >= candidate on ALL, other > candidate on SOME
|
| 39 |
if np.all(other_scores >= cand_scores) and np.any(other_scores > cand_scores):
|
|
|
|
| 61 |
# --- Plot 1: Bar Chart (since pandas might not be available) ---
|
| 62 |
ax1 = plt.subplot(1, 2, 1)
|
| 63 |
|
| 64 |
+
metrics = ["Clinical\nAccuracy", "Evidence\nGrounding", "Actionability", "Clarity", "Safety"]
|
| 65 |
x = np.arange(len(metrics))
|
| 66 |
width = 0.8 / len(pareto_front)
|
| 67 |
|
| 68 |
for idx, entry in enumerate(pareto_front):
|
| 69 |
+
e = entry["evaluation"]
|
| 70 |
scores = [
|
| 71 |
e.clinical_accuracy.score,
|
| 72 |
e.evidence_grounding.score,
|
| 73 |
e.actionability.score,
|
| 74 |
e.clarity.score,
|
| 75 |
+
e.safety_completeness.score,
|
| 76 |
]
|
| 77 |
|
| 78 |
offset = (idx - len(pareto_front) / 2) * width + width / 2
|
| 79 |
label = f"SOP v{entry['version']}"
|
| 80 |
ax1.bar(x + offset, scores, width, label=label, alpha=0.8)
|
| 81 |
|
| 82 |
+
ax1.set_xlabel("Metrics", fontsize=12)
|
| 83 |
+
ax1.set_ylabel("Score", fontsize=12)
|
| 84 |
+
ax1.set_title("5D Performance Comparison (Bar Chart)", fontsize=14)
|
| 85 |
ax1.set_xticks(x)
|
| 86 |
ax1.set_xticklabels(metrics, fontsize=10)
|
| 87 |
ax1.set_ylim(0, 1.0)
|
| 88 |
+
ax1.legend(loc="upper left")
|
| 89 |
+
ax1.grid(True, alpha=0.3, axis="y")
|
| 90 |
|
| 91 |
# --- Plot 2: Radar Chart ---
|
| 92 |
+
ax2 = plt.subplot(1, 2, 2, projection="polar")
|
| 93 |
|
| 94 |
+
categories = ["Clinical\nAccuracy", "Evidence\nGrounding", "Actionability", "Clarity", "Safety"]
|
|
|
|
| 95 |
num_vars = len(categories)
|
| 96 |
|
| 97 |
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
| 98 |
angles += angles[:1]
|
| 99 |
|
| 100 |
for entry in pareto_front:
|
| 101 |
+
e = entry["evaluation"]
|
| 102 |
values = [
|
| 103 |
e.clinical_accuracy.score,
|
| 104 |
e.evidence_grounding.score,
|
| 105 |
e.actionability.score,
|
| 106 |
e.clarity.score,
|
| 107 |
+
e.safety_completeness.score,
|
| 108 |
]
|
| 109 |
values += values[:1]
|
| 110 |
|
| 111 |
+
desc = entry.get("description", "")[:30]
|
| 112 |
label = f"SOP v{entry['version']}: {desc}"
|
| 113 |
+
ax2.plot(angles, values, "o-", linewidth=2, label=label)
|
| 114 |
ax2.fill(angles, values, alpha=0.15)
|
| 115 |
|
| 116 |
ax2.set_xticks(angles[:-1])
|
| 117 |
ax2.set_xticklabels(categories, size=10)
|
| 118 |
ax2.set_ylim(0, 1)
|
| 119 |
+
ax2.set_title("5D Performance Profiles (Radar Chart)", size=14, y=1.08)
|
| 120 |
+
ax2.legend(loc="upper left", bbox_to_anchor=(1.2, 1.0), fontsize=9)
|
| 121 |
ax2.grid(True)
|
| 122 |
|
| 123 |
plt.tight_layout()
|
| 124 |
|
| 125 |
# Create data directory if it doesn't exist
|
| 126 |
from pathlib import Path
|
| 127 |
+
|
| 128 |
+
data_dir = Path("data")
|
| 129 |
data_dir.mkdir(exist_ok=True)
|
| 130 |
|
| 131 |
+
output_path = data_dir / "pareto_frontier_analysis.png"
|
| 132 |
+
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
| 133 |
plt.close()
|
| 134 |
|
| 135 |
print(f"\nβ Visualization saved to: {output_path}")
|
|
|
|
| 144 |
print(f"\nFound {len(pareto_front)} optimal (non-dominated) solutions:\n")
|
| 145 |
|
| 146 |
for entry in pareto_front:
|
| 147 |
+
v = entry["version"]
|
| 148 |
+
p = entry.get("parent")
|
| 149 |
+
desc = entry.get("description", "Baseline")
|
| 150 |
+
e = entry["evaluation"]
|
| 151 |
|
| 152 |
print(f"SOP v{v} {f'(Child of v{p})' if p else '(Baseline)'}")
|
| 153 |
print(f" Description: {desc}")
|
|
|
|
| 176 |
return
|
| 177 |
|
| 178 |
baseline = gene_pool_entries[0]
|
| 179 |
+
baseline_scores = np.array(baseline["evaluation"].to_vector())
|
| 180 |
|
| 181 |
print("\n" + "=" * 80)
|
| 182 |
print("IMPROVEMENT ANALYSIS")
|
|
|
|
| 187 |
|
| 188 |
improvements_found = False
|
| 189 |
for entry in gene_pool_entries[1:]:
|
| 190 |
+
scores = np.array(entry["evaluation"].to_vector())
|
| 191 |
avg_score = np.mean(scores)
|
| 192 |
baseline_avg = np.mean(baseline_scores)
|
| 193 |
|
|
|
|
| 199 |
print(f" Average Score: {avg_score:.3f} (+{improvement_pct:.1f}% vs baseline)")
|
| 200 |
|
| 201 |
# Show per-metric improvements
|
| 202 |
+
metric_names = [
|
| 203 |
+
"Clinical Accuracy",
|
| 204 |
+
"Evidence Grounding",
|
| 205 |
+
"Actionability",
|
| 206 |
+
"Clarity",
|
| 207 |
+
"Safety & Completeness",
|
| 208 |
+
]
|
| 209 |
for i, (name, score, baseline_score) in enumerate(zip(metric_names, scores, baseline_scores)):
|
| 210 |
diff = score - baseline_score
|
| 211 |
if abs(diff) > 0.01: # Show significant changes
|
archive/tests/test_evolution_loop.py
CHANGED
|
@@ -51,35 +51,27 @@ def create_test_patient() -> PatientInput:
|
|
| 51 |
"Sodium": 138.0,
|
| 52 |
"Potassium": 4.2,
|
| 53 |
"Chloride": 102.0,
|
| 54 |
-
"Bicarbonate": 24.0
|
| 55 |
}
|
| 56 |
|
| 57 |
model_prediction: dict[str, Any] = {
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
'Prediabetes': 0.05,
|
| 63 |
-
'Healthy': 0.03
|
| 64 |
-
},
|
| 65 |
-
'prediction_timestamp': '2025-01-01T10:00:00'
|
| 66 |
}
|
| 67 |
|
| 68 |
patient_context = {
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
}
|
| 77 |
|
| 78 |
-
return PatientInput(
|
| 79 |
-
biomarkers=biomarkers,
|
| 80 |
-
model_prediction=model_prediction,
|
| 81 |
-
patient_context=patient_context
|
| 82 |
-
)
|
| 83 |
|
| 84 |
|
| 85 |
def main():
|
|
@@ -101,36 +93,29 @@ def main():
|
|
| 101 |
# Run workflow with baseline SOP
|
| 102 |
|
| 103 |
initial_state: GuildState = {
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
}
|
| 116 |
|
| 117 |
guild_state = guild.workflow.invoke(initial_state)
|
| 118 |
|
| 119 |
-
baseline_response = guild_state[
|
| 120 |
-
agent_outputs = guild_state[
|
| 121 |
|
| 122 |
baseline_eval = run_full_evaluation(
|
| 123 |
-
final_response=baseline_response,
|
| 124 |
-
agent_outputs=agent_outputs,
|
| 125 |
-
biomarkers=patient.biomarkers
|
| 126 |
)
|
| 127 |
|
| 128 |
-
gene_pool.add(
|
| 129 |
-
sop=BASELINE_SOP,
|
| 130 |
-
evaluation=baseline_eval,
|
| 131 |
-
parent_version=None,
|
| 132 |
-
description="Baseline SOP"
|
| 133 |
-
)
|
| 134 |
|
| 135 |
print(f"\nβ Baseline Average Score: {baseline_eval.average_score():.3f}")
|
| 136 |
print(f" Clinical Accuracy: {baseline_eval.clinical_accuracy.score:.3f}")
|
|
@@ -152,16 +137,11 @@ def main():
|
|
| 152 |
# Create evaluation function for this cycle
|
| 153 |
def eval_func(final_response, agent_outputs, biomarkers):
|
| 154 |
return run_full_evaluation(
|
| 155 |
-
final_response=final_response,
|
| 156 |
-
agent_outputs=agent_outputs,
|
| 157 |
-
biomarkers=biomarkers
|
| 158 |
)
|
| 159 |
|
| 160 |
new_entries = run_evolution_cycle(
|
| 161 |
-
gene_pool=gene_pool,
|
| 162 |
-
patient_input=patient,
|
| 163 |
-
workflow_graph=guild.workflow,
|
| 164 |
-
evaluation_func=eval_func
|
| 165 |
)
|
| 166 |
|
| 167 |
print(f"\nβ Cycle {cycle} complete: Added {len(new_entries)} new SOPs to gene pool")
|
|
@@ -203,9 +183,9 @@ def main():
|
|
| 203 |
print(f"β Pareto Optimal SOPs: {len(pareto_front)}")
|
| 204 |
|
| 205 |
# Find best average score
|
| 206 |
-
best_sop = max(all_entries, key=lambda e: e[
|
| 207 |
baseline_avg = baseline_eval.average_score()
|
| 208 |
-
best_avg = best_sop[
|
| 209 |
improvement = ((best_avg - baseline_avg) / baseline_avg) * 100
|
| 210 |
|
| 211 |
print(f"\nBest SOP: v{best_sop['version']} - {best_sop['description']}")
|
|
|
|
| 51 |
"Sodium": 138.0,
|
| 52 |
"Potassium": 4.2,
|
| 53 |
"Chloride": 102.0,
|
| 54 |
+
"Bicarbonate": 24.0,
|
| 55 |
}
|
| 56 |
|
| 57 |
model_prediction: dict[str, Any] = {
|
| 58 |
+
"disease": "Type 2 Diabetes",
|
| 59 |
+
"confidence": 0.92,
|
| 60 |
+
"probabilities": {"Type 2 Diabetes": 0.92, "Prediabetes": 0.05, "Healthy": 0.03},
|
| 61 |
+
"prediction_timestamp": "2025-01-01T10:00:00",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
patient_context = {
|
| 65 |
+
"patient_id": "TEST-001",
|
| 66 |
+
"age": 55,
|
| 67 |
+
"gender": "male",
|
| 68 |
+
"symptoms": ["Increased thirst", "Frequent urination", "Fatigue"],
|
| 69 |
+
"medical_history": ["Prediabetes diagnosed 2 years ago"],
|
| 70 |
+
"current_medications": ["Metformin 500mg"],
|
| 71 |
+
"query": "My blood sugar has been high lately. What should I do?",
|
| 72 |
}
|
| 73 |
|
| 74 |
+
return PatientInput(biomarkers=biomarkers, model_prediction=model_prediction, patient_context=patient_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
def main():
|
|
|
|
| 93 |
# Run workflow with baseline SOP
|
| 94 |
|
| 95 |
initial_state: GuildState = {
|
| 96 |
+
"patient_biomarkers": patient.biomarkers,
|
| 97 |
+
"model_prediction": patient.model_prediction,
|
| 98 |
+
"patient_context": patient.patient_context,
|
| 99 |
+
"plan": None,
|
| 100 |
+
"sop": BASELINE_SOP,
|
| 101 |
+
"agent_outputs": [],
|
| 102 |
+
"biomarker_flags": [],
|
| 103 |
+
"safety_alerts": [],
|
| 104 |
+
"final_response": None,
|
| 105 |
+
"processing_timestamp": datetime.now().isoformat(),
|
| 106 |
+
"sop_version": "Baseline",
|
| 107 |
}
|
| 108 |
|
| 109 |
guild_state = guild.workflow.invoke(initial_state)
|
| 110 |
|
| 111 |
+
baseline_response = guild_state["final_response"]
|
| 112 |
+
agent_outputs = guild_state["agent_outputs"]
|
| 113 |
|
| 114 |
baseline_eval = run_full_evaluation(
|
| 115 |
+
final_response=baseline_response, agent_outputs=agent_outputs, biomarkers=patient.biomarkers
|
|
|
|
|
|
|
| 116 |
)
|
| 117 |
|
| 118 |
+
gene_pool.add(sop=BASELINE_SOP, evaluation=baseline_eval, parent_version=None, description="Baseline SOP")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
print(f"\nβ Baseline Average Score: {baseline_eval.average_score():.3f}")
|
| 121 |
print(f" Clinical Accuracy: {baseline_eval.clinical_accuracy.score:.3f}")
|
|
|
|
| 137 |
# Create evaluation function for this cycle
|
| 138 |
def eval_func(final_response, agent_outputs, biomarkers):
|
| 139 |
return run_full_evaluation(
|
| 140 |
+
final_response=final_response, agent_outputs=agent_outputs, biomarkers=biomarkers
|
|
|
|
|
|
|
| 141 |
)
|
| 142 |
|
| 143 |
new_entries = run_evolution_cycle(
|
| 144 |
+
gene_pool=gene_pool, patient_input=patient, workflow_graph=guild.workflow, evaluation_func=eval_func
|
|
|
|
|
|
|
|
|
|
| 145 |
)
|
| 146 |
|
| 147 |
print(f"\nβ Cycle {cycle} complete: Added {len(new_entries)} new SOPs to gene pool")
|
|
|
|
| 183 |
print(f"β Pareto Optimal SOPs: {len(pareto_front)}")
|
| 184 |
|
| 185 |
# Find best average score
|
| 186 |
+
best_sop = max(all_entries, key=lambda e: e["evaluation"].average_score())
|
| 187 |
baseline_avg = baseline_eval.average_score()
|
| 188 |
+
best_avg = best_sop["evaluation"].average_score()
|
| 189 |
improvement = ((best_avg - baseline_avg) / baseline_avg) * 100
|
| 190 |
|
| 191 |
print(f"\nBest SOP: v{best_sop['version']} - {best_sop['description']}")
|
archive/tests/test_evolution_quick.py
CHANGED
|
@@ -29,15 +29,10 @@ def main():
|
|
| 29 |
evidence_grounding=GradedScore(score=1.0, reasoning="Well cited"),
|
| 30 |
actionability=GradedScore(score=0.90, reasoning="Clear actions"),
|
| 31 |
clarity=GradedScore(score=0.75, reasoning="Could be clearer"),
|
| 32 |
-
safety_completeness=GradedScore(score=1.0, reasoning="Complete")
|
| 33 |
)
|
| 34 |
|
| 35 |
-
gene_pool.add(
|
| 36 |
-
sop=BASELINE_SOP,
|
| 37 |
-
evaluation=baseline_eval,
|
| 38 |
-
parent_version=None,
|
| 39 |
-
description="Baseline SOP"
|
| 40 |
-
)
|
| 41 |
|
| 42 |
print("β Gene pool initialized with 1 SOP")
|
| 43 |
print(f" Average score: {baseline_eval.average_score():.3f}")
|
|
|
|
| 29 |
evidence_grounding=GradedScore(score=1.0, reasoning="Well cited"),
|
| 30 |
actionability=GradedScore(score=0.90, reasoning="Clear actions"),
|
| 31 |
clarity=GradedScore(score=0.75, reasoning="Could be clearer"),
|
| 32 |
+
safety_completeness=GradedScore(score=1.0, reasoning="Complete"),
|
| 33 |
)
|
| 34 |
|
| 35 |
+
gene_pool.add(sop=BASELINE_SOP, evaluation=baseline_eval, parent_version=None, description="Baseline SOP")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
print("β Gene pool initialized with 1 SOP")
|
| 38 |
print(f" Average score: {baseline_eval.average_score():.3f}")
|
conftest.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Empty conftest to add the root project directory to pytest's sys.path
|
huggingface/app.py
CHANGED
|
@@ -11,16 +11,16 @@ Environment Variables (HuggingFace Secrets):
|
|
| 11 |
Required (pick one):
|
| 12 |
- GROQ_API_KEY: Groq API key (recommended, free)
|
| 13 |
- GOOGLE_API_KEY: Google Gemini API key (free)
|
| 14 |
-
|
| 15 |
Optional - LLM Configuration:
|
| 16 |
- LLM_PROVIDER: "groq" or "gemini" (auto-detected from keys)
|
| 17 |
- GROQ_MODEL: Model name (default: llama-3.3-70b-versatile)
|
| 18 |
- GEMINI_MODEL: Model name (default: gemini-2.0-flash)
|
| 19 |
-
|
| 20 |
Optional - Embeddings:
|
| 21 |
- EMBEDDING_PROVIDER: "jina", "google", or "huggingface" (default: huggingface)
|
| 22 |
- JINA_API_KEY: Jina AI API key for high-quality embeddings
|
| 23 |
-
|
| 24 |
Optional - Observability:
|
| 25 |
- LANGFUSE_ENABLED: "true" to enable tracing
|
| 26 |
- LANGFUSE_PUBLIC_KEY: Langfuse public key
|
|
@@ -57,6 +57,7 @@ logger = logging.getLogger("mediguard.huggingface")
|
|
| 57 |
# Configuration - Environment Variable Helpers
|
| 58 |
# ---------------------------------------------------------------------------
|
| 59 |
|
|
|
|
| 60 |
def _get_env(primary: str, *fallbacks, default: str = "") -> str:
|
| 61 |
"""Get env var with multiple fallback names for compatibility."""
|
| 62 |
value = os.getenv(primary)
|
|
@@ -71,7 +72,7 @@ def _get_env(primary: str, *fallbacks, default: str = "") -> str:
|
|
| 71 |
|
| 72 |
def get_api_keys():
|
| 73 |
"""Get API keys dynamically (HuggingFace injects secrets after module load).
|
| 74 |
-
|
| 75 |
Supports both simple and nested naming conventions:
|
| 76 |
- GROQ_API_KEY / LLM__GROQ_API_KEY
|
| 77 |
- GOOGLE_API_KEY / LLM__GOOGLE_API_KEY
|
|
@@ -109,7 +110,7 @@ def is_langfuse_enabled() -> bool:
|
|
| 109 |
|
| 110 |
def setup_llm_provider():
|
| 111 |
"""Set up LLM provider and related configuration based on available keys.
|
| 112 |
-
|
| 113 |
Sets environment variables for the entire application to use.
|
| 114 |
"""
|
| 115 |
groq_key, google_key = get_api_keys()
|
|
@@ -164,9 +165,7 @@ logger.info(f"EMBEDDING_PROVIDER: {get_embedding_provider()}")
|
|
| 164 |
logger.info(f"LANGFUSE: {'β enabled' if is_langfuse_enabled() else 'β disabled'}")
|
| 165 |
|
| 166 |
if not _groq and not _google:
|
| 167 |
-
logger.warning(
|
| 168 |
-
"No LLM API key found at startup. Will check again when analyzing."
|
| 169 |
-
)
|
| 170 |
else:
|
| 171 |
logger.info("LLM API key available β ready for analysis")
|
| 172 |
logger.info("=" * 60)
|
|
@@ -218,6 +217,7 @@ def get_guild():
|
|
| 218 |
start = time.time()
|
| 219 |
|
| 220 |
from src.workflow import create_guild
|
|
|
|
| 221 |
_guild = create_guild()
|
| 222 |
_guild_provider = current_provider
|
| 223 |
|
|
@@ -254,22 +254,29 @@ def auto_predict(biomarkers: dict[str, float]) -> dict[str, Any]:
|
|
| 254 |
def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, str, str]:
|
| 255 |
"""
|
| 256 |
Analyze biomarkers using the Clinical Insight Guild.
|
| 257 |
-
|
| 258 |
Returns: (summary, details_json, status)
|
| 259 |
"""
|
| 260 |
if not input_text.strip():
|
| 261 |
-
return
|
|
|
|
|
|
|
|
|
|
| 262 |
<div style="background: linear-gradient(135deg, #f0f4f8 0%, #e2e8f0 100%); border: 1px solid #cbd5e1; border-radius: 10px; padding: 16px; text-align: center;">
|
| 263 |
<span style="font-size: 2em;">βοΈ</span>
|
| 264 |
<p style="margin: 8px 0 0 0; color: #64748b;">Please enter biomarkers to analyze.</p>
|
| 265 |
</div>
|
| 266 |
-
"""
|
|
|
|
| 267 |
|
| 268 |
# Check API key dynamically (HF injects secrets after startup)
|
| 269 |
groq_key, google_key = get_api_keys()
|
| 270 |
|
| 271 |
if not groq_key and not google_key:
|
| 272 |
-
return
|
|
|
|
|
|
|
|
|
|
| 273 |
<div style="background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%); border: 1px solid #ef4444; border-radius: 10px; padding: 16px;">
|
| 274 |
<strong style="color: #dc2626;">β No API Key Configured</strong>
|
| 275 |
<p style="margin: 12px 0 8px 0; color: #991b1b;">Please add your API key in Space Settings β Secrets:</p>
|
|
@@ -293,7 +300,8 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
|
|
| 293 |
</ul>
|
| 294 |
</details>
|
| 295 |
</div>
|
| 296 |
-
"""
|
|
|
|
| 297 |
|
| 298 |
# Setup provider based on available key
|
| 299 |
provider = setup_llm_provider()
|
|
@@ -304,7 +312,10 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
|
|
| 304 |
biomarkers = parse_biomarkers(input_text)
|
| 305 |
|
| 306 |
if not biomarkers:
|
| 307 |
-
return
|
|
|
|
|
|
|
|
|
|
| 308 |
<div style="background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); border: 1px solid #fbbf24; border-radius: 10px; padding: 16px;">
|
| 309 |
<strong>β οΈ Could not parse biomarkers</strong>
|
| 310 |
<p style="margin: 8px 0 0 0; color: #92400e;">Try formats like:</p>
|
|
@@ -313,7 +324,8 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
|
|
| 313 |
<li><code>{"Glucose": 140, "HbA1c": 7.5}</code></li>
|
| 314 |
</ul>
|
| 315 |
</div>
|
| 316 |
-
"""
|
|
|
|
| 317 |
|
| 318 |
progress(0.2, desc="π§ Initializing AI agents...")
|
| 319 |
|
|
@@ -329,7 +341,7 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
|
|
| 329 |
patient_input = PatientInput(
|
| 330 |
biomarkers=biomarkers,
|
| 331 |
model_prediction=prediction,
|
| 332 |
-
patient_context={"patient_id": "HF_User", "source": "huggingface_spaces"}
|
| 333 |
)
|
| 334 |
|
| 335 |
progress(0.4, desc="π€ Running Clinical Insight Guild...")
|
|
@@ -395,7 +407,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 395 |
"critical": ("π΄", "#dc2626", "#fef2f2"),
|
| 396 |
"high": ("π ", "#ea580c", "#fff7ed"),
|
| 397 |
"moderate": ("π‘", "#ca8a04", "#fefce8"),
|
| 398 |
-
"low": ("π’", "#16a34a", "#f0fdf4")
|
| 399 |
}
|
| 400 |
emoji, color, bg_color = severity_config.get(severity, severity_config["low"])
|
| 401 |
|
|
@@ -421,9 +433,11 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 421 |
alert_items = ""
|
| 422 |
for alert in alerts[:5]:
|
| 423 |
if isinstance(alert, dict):
|
| 424 |
-
alert_items +=
|
|
|
|
|
|
|
| 425 |
else:
|
| 426 |
-
alert_items += f
|
| 427 |
|
| 428 |
parts.append(f"""
|
| 429 |
<div style="background: linear-gradient(135deg, #fef2f2 0%, #fee2e2 100%); border: 1px solid #fecaca; border-radius: 12px; padding: 16px; margin-bottom: 16px;">
|
|
@@ -463,7 +477,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 463 |
"high": ("π΄", "#dc2626", "#fef2f2"),
|
| 464 |
"abnormal": ("π‘", "#ca8a04", "#fefce8"),
|
| 465 |
"low": ("π‘", "#ca8a04", "#fefce8"),
|
| 466 |
-
"normal": ("π’", "#16a34a", "#f0fdf4")
|
| 467 |
}
|
| 468 |
s_emoji, s_color, s_bg = status_styles.get(status, status_styles["normal"])
|
| 469 |
|
|
@@ -549,7 +563,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 549 |
parts.append(f"""
|
| 550 |
<div style="background: #f8fafc; border-radius: 12px; padding: 16px; margin-bottom: 16px;">
|
| 551 |
<h4 style="margin: 0 0 12px 0; color: #1e3a5f;">π Understanding Your Results</h4>
|
| 552 |
-
<p style="margin: 0; color: #475569; line-height: 1.6;">{pathophys[:600]}{
|
| 553 |
</div>
|
| 554 |
""")
|
| 555 |
|
|
@@ -659,14 +673,10 @@ Question: {question}
|
|
| 659 |
|
| 660 |
Answer:"""
|
| 661 |
response = llm.invoke(prompt)
|
| 662 |
-
return response.content if hasattr(response,
|
| 663 |
|
| 664 |
|
| 665 |
-
def answer_medical_question(
|
| 666 |
-
question: str,
|
| 667 |
-
context: str = "",
|
| 668 |
-
chat_history: list = None
|
| 669 |
-
) -> tuple[str, list]:
|
| 670 |
"""Answer a medical question using the full agentic RAG pipeline.
|
| 671 |
|
| 672 |
Pipeline: guardrail β retrieve β grade β rewrite β generate.
|
|
@@ -819,6 +829,7 @@ def hf_search(query: str, mode: str):
|
|
| 819 |
return "Please enter a query."
|
| 820 |
try:
|
| 821 |
from src.services.retrieval.factory import make_retriever
|
|
|
|
| 822 |
retriever = make_retriever()
|
| 823 |
docs = retriever.retrieve(query, top_k=5)
|
| 824 |
if not docs:
|
|
@@ -826,7 +837,7 @@ def hf_search(query: str, mode: str):
|
|
| 826 |
parts = []
|
| 827 |
for i, doc in enumerate(docs, 1):
|
| 828 |
title = doc.metadata.get("title", doc.metadata.get("source_file", "Untitled"))
|
| 829 |
-
score = doc.score if hasattr(doc,
|
| 830 |
parts.append(f"**[{i}] {title}** (score: {score:.3f})\n{doc.content}\n")
|
| 831 |
return "\n---\n".join(parts)
|
| 832 |
except Exception as exc:
|
|
@@ -1095,7 +1106,6 @@ def create_demo() -> gr.Blocks:
|
|
| 1095 |
),
|
| 1096 |
css=CUSTOM_CSS,
|
| 1097 |
) as demo:
|
| 1098 |
-
|
| 1099 |
# ===== HEADER =====
|
| 1100 |
gr.HTML("""
|
| 1101 |
<div class="header-container">
|
|
@@ -1129,13 +1139,10 @@ def create_demo() -> gr.Blocks:
|
|
| 1129 |
|
| 1130 |
# ===== MAIN TABS =====
|
| 1131 |
with gr.Tabs() as main_tabs:
|
| 1132 |
-
|
| 1133 |
# ==================== TAB 1: BIOMARKER ANALYSIS ====================
|
| 1134 |
with gr.Tab("π¬ Biomarker Analysis", id="biomarker-tab"):
|
| 1135 |
-
|
| 1136 |
# ===== MAIN CONTENT =====
|
| 1137 |
with gr.Row(equal_height=False):
|
| 1138 |
-
|
| 1139 |
# ----- LEFT PANEL: INPUT -----
|
| 1140 |
with gr.Column(scale=2, min_width=400):
|
| 1141 |
gr.HTML('<div class="section-title">π Enter Your Biomarkers</div>')
|
|
@@ -1143,7 +1150,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1143 |
with gr.Group():
|
| 1144 |
input_text = gr.Textbox(
|
| 1145 |
label="",
|
| 1146 |
-
placeholder=
|
| 1147 |
lines=6,
|
| 1148 |
max_lines=12,
|
| 1149 |
show_label=False,
|
|
@@ -1164,14 +1171,13 @@ def create_demo() -> gr.Blocks:
|
|
| 1164 |
)
|
| 1165 |
|
| 1166 |
# Status display
|
| 1167 |
-
status_output = gr.Markdown(
|
| 1168 |
-
value="",
|
| 1169 |
-
elem_classes="status-box"
|
| 1170 |
-
)
|
| 1171 |
|
| 1172 |
# Quick Examples
|
| 1173 |
gr.HTML('<div class="section-title" style="margin-top: 24px;">β‘ Quick Examples</div>')
|
| 1174 |
-
gr.HTML(
|
|
|
|
|
|
|
| 1175 |
|
| 1176 |
examples = gr.Examples(
|
| 1177 |
examples=[
|
|
@@ -1230,7 +1236,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1230 |
<p>Enter your biomarkers on the left and click <strong>Analyze</strong> to get your personalized health insights.</p>
|
| 1231 |
</div>
|
| 1232 |
""",
|
| 1233 |
-
elem_classes="summary-output"
|
| 1234 |
)
|
| 1235 |
|
| 1236 |
with gr.Tab("π Detailed JSON", id="json"):
|
|
@@ -1243,7 +1249,6 @@ def create_demo() -> gr.Blocks:
|
|
| 1243 |
|
| 1244 |
# ==================== TAB 2: MEDICAL Q&A ====================
|
| 1245 |
with gr.Tab("π¬ Medical Q&A", id="qa-tab"):
|
| 1246 |
-
|
| 1247 |
gr.HTML("""
|
| 1248 |
<div style="margin-bottom: 20px;">
|
| 1249 |
<h3 style="color: #1e3a5f; margin: 0 0 8px 0;">π¬ Medical Q&A Assistant</h3>
|
|
@@ -1264,7 +1269,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1264 |
qa_model = gr.Dropdown(
|
| 1265 |
choices=["llama-3.3-70b-versatile", "gemini-2.0-flash", "llama3.1:8b"],
|
| 1266 |
value="llama-3.3-70b-versatile",
|
| 1267 |
-
label="LLM Provider/Model"
|
| 1268 |
)
|
| 1269 |
qa_question = gr.Textbox(
|
| 1270 |
label="Your Question",
|
|
@@ -1301,11 +1306,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1301 |
|
| 1302 |
with gr.Column(scale=2):
|
| 1303 |
gr.HTML('<h4 style="color: #1e3a5f; margin-bottom: 12px;">π Answer</h4>')
|
| 1304 |
-
qa_answer = gr.Chatbot(
|
| 1305 |
-
label="Medical Q&A History",
|
| 1306 |
-
height=600,
|
| 1307 |
-
elem_classes="qa-output"
|
| 1308 |
-
)
|
| 1309 |
|
| 1310 |
# Q&A Event Handlers
|
| 1311 |
qa_submit_btn.click(
|
|
@@ -1313,10 +1314,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1313 |
inputs=[qa_question, qa_context, qa_answer, qa_model],
|
| 1314 |
outputs=qa_answer,
|
| 1315 |
show_progress="minimal",
|
| 1316 |
-
).then(
|
| 1317 |
-
fn=lambda: "",
|
| 1318 |
-
outputs=qa_question
|
| 1319 |
-
)
|
| 1320 |
|
| 1321 |
qa_clear_btn.click(
|
| 1322 |
fn=lambda: ([], ""),
|
|
@@ -1327,16 +1325,10 @@ def create_demo() -> gr.Blocks:
|
|
| 1327 |
with gr.Tab("π Search Knowledge Base", id="search-tab"):
|
| 1328 |
with gr.Row():
|
| 1329 |
search_input = gr.Textbox(
|
| 1330 |
-
label="Search Query",
|
| 1331 |
-
placeholder="e.g., diabetes management guidelines",
|
| 1332 |
-
lines=2,
|
| 1333 |
-
scale=3
|
| 1334 |
)
|
| 1335 |
search_mode = gr.Radio(
|
| 1336 |
-
choices=["hybrid", "bm25", "vector"],
|
| 1337 |
-
value="hybrid",
|
| 1338 |
-
label="Search Strategy",
|
| 1339 |
-
scale=1
|
| 1340 |
)
|
| 1341 |
search_btn = gr.Button("Search", variant="primary")
|
| 1342 |
search_output = gr.Textbox(label="Results", lines=20, interactive=False)
|
|
@@ -1409,13 +1401,18 @@ def create_demo() -> gr.Blocks:
|
|
| 1409 |
)
|
| 1410 |
|
| 1411 |
clear_btn.click(
|
| 1412 |
-
fn=lambda: (
|
|
|
|
|
|
|
| 1413 |
<div style="text-align: center; padding: 60px 20px; color: #94a3b8;">
|
| 1414 |
<div style="font-size: 4em; margin-bottom: 16px;">π¬</div>
|
| 1415 |
<h3 style="color: #64748b; font-weight: 500;">Ready to Analyze</h3>
|
| 1416 |
<p>Enter your biomarkers on the left and click <strong>Analyze</strong> to get your personalized health insights.</p>
|
| 1417 |
</div>
|
| 1418 |
-
""",
|
|
|
|
|
|
|
|
|
|
| 1419 |
outputs=[input_text, summary_output, details_output, status_output],
|
| 1420 |
)
|
| 1421 |
|
|
|
|
| 11 |
Required (pick one):
|
| 12 |
- GROQ_API_KEY: Groq API key (recommended, free)
|
| 13 |
- GOOGLE_API_KEY: Google Gemini API key (free)
|
| 14 |
+
|
| 15 |
Optional - LLM Configuration:
|
| 16 |
- LLM_PROVIDER: "groq" or "gemini" (auto-detected from keys)
|
| 17 |
- GROQ_MODEL: Model name (default: llama-3.3-70b-versatile)
|
| 18 |
- GEMINI_MODEL: Model name (default: gemini-2.0-flash)
|
| 19 |
+
|
| 20 |
Optional - Embeddings:
|
| 21 |
- EMBEDDING_PROVIDER: "jina", "google", or "huggingface" (default: huggingface)
|
| 22 |
- JINA_API_KEY: Jina AI API key for high-quality embeddings
|
| 23 |
+
|
| 24 |
Optional - Observability:
|
| 25 |
- LANGFUSE_ENABLED: "true" to enable tracing
|
| 26 |
- LANGFUSE_PUBLIC_KEY: Langfuse public key
|
|
|
|
| 57 |
# Configuration - Environment Variable Helpers
|
| 58 |
# ---------------------------------------------------------------------------
|
| 59 |
|
| 60 |
+
|
| 61 |
def _get_env(primary: str, *fallbacks, default: str = "") -> str:
|
| 62 |
"""Get env var with multiple fallback names for compatibility."""
|
| 63 |
value = os.getenv(primary)
|
|
|
|
| 72 |
|
| 73 |
def get_api_keys():
|
| 74 |
"""Get API keys dynamically (HuggingFace injects secrets after module load).
|
| 75 |
+
|
| 76 |
Supports both simple and nested naming conventions:
|
| 77 |
- GROQ_API_KEY / LLM__GROQ_API_KEY
|
| 78 |
- GOOGLE_API_KEY / LLM__GOOGLE_API_KEY
|
|
|
|
| 110 |
|
| 111 |
def setup_llm_provider():
|
| 112 |
"""Set up LLM provider and related configuration based on available keys.
|
| 113 |
+
|
| 114 |
Sets environment variables for the entire application to use.
|
| 115 |
"""
|
| 116 |
groq_key, google_key = get_api_keys()
|
|
|
|
| 165 |
logger.info(f"LANGFUSE: {'β enabled' if is_langfuse_enabled() else 'β disabled'}")
|
| 166 |
|
| 167 |
if not _groq and not _google:
|
| 168 |
+
logger.warning("No LLM API key found at startup. Will check again when analyzing.")
|
|
|
|
|
|
|
| 169 |
else:
|
| 170 |
logger.info("LLM API key available β ready for analysis")
|
| 171 |
logger.info("=" * 60)
|
|
|
|
| 217 |
start = time.time()
|
| 218 |
|
| 219 |
from src.workflow import create_guild
|
| 220 |
+
|
| 221 |
_guild = create_guild()
|
| 222 |
_guild_provider = current_provider
|
| 223 |
|
|
|
|
| 254 |
def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, str, str]:
|
| 255 |
"""
|
| 256 |
Analyze biomarkers using the Clinical Insight Guild.
|
| 257 |
+
|
| 258 |
Returns: (summary, details_json, status)
|
| 259 |
"""
|
| 260 |
if not input_text.strip():
|
| 261 |
+
return (
|
| 262 |
+
"",
|
| 263 |
+
"",
|
| 264 |
+
"""
|
| 265 |
<div style="background: linear-gradient(135deg, #f0f4f8 0%, #e2e8f0 100%); border: 1px solid #cbd5e1; border-radius: 10px; padding: 16px; text-align: center;">
|
| 266 |
<span style="font-size: 2em;">βοΈ</span>
|
| 267 |
<p style="margin: 8px 0 0 0; color: #64748b;">Please enter biomarkers to analyze.</p>
|
| 268 |
</div>
|
| 269 |
+
""",
|
| 270 |
+
)
|
| 271 |
|
| 272 |
# Check API key dynamically (HF injects secrets after startup)
|
| 273 |
groq_key, google_key = get_api_keys()
|
| 274 |
|
| 275 |
if not groq_key and not google_key:
|
| 276 |
+
return (
|
| 277 |
+
"",
|
| 278 |
+
"",
|
| 279 |
+
"""
|
| 280 |
<div style="background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%); border: 1px solid #ef4444; border-radius: 10px; padding: 16px;">
|
| 281 |
<strong style="color: #dc2626;">β No API Key Configured</strong>
|
| 282 |
<p style="margin: 12px 0 8px 0; color: #991b1b;">Please add your API key in Space Settings β Secrets:</p>
|
|
|
|
| 300 |
</ul>
|
| 301 |
</details>
|
| 302 |
</div>
|
| 303 |
+
""",
|
| 304 |
+
)
|
| 305 |
|
| 306 |
# Setup provider based on available key
|
| 307 |
provider = setup_llm_provider()
|
|
|
|
| 312 |
biomarkers = parse_biomarkers(input_text)
|
| 313 |
|
| 314 |
if not biomarkers:
|
| 315 |
+
return (
|
| 316 |
+
"",
|
| 317 |
+
"",
|
| 318 |
+
"""
|
| 319 |
<div style="background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); border: 1px solid #fbbf24; border-radius: 10px; padding: 16px;">
|
| 320 |
<strong>β οΈ Could not parse biomarkers</strong>
|
| 321 |
<p style="margin: 8px 0 0 0; color: #92400e;">Try formats like:</p>
|
|
|
|
| 324 |
<li><code>{"Glucose": 140, "HbA1c": 7.5}</code></li>
|
| 325 |
</ul>
|
| 326 |
</div>
|
| 327 |
+
""",
|
| 328 |
+
)
|
| 329 |
|
| 330 |
progress(0.2, desc="π§ Initializing AI agents...")
|
| 331 |
|
|
|
|
| 341 |
patient_input = PatientInput(
|
| 342 |
biomarkers=biomarkers,
|
| 343 |
model_prediction=prediction,
|
| 344 |
+
patient_context={"patient_id": "HF_User", "source": "huggingface_spaces"},
|
| 345 |
)
|
| 346 |
|
| 347 |
progress(0.4, desc="π€ Running Clinical Insight Guild...")
|
|
|
|
| 407 |
"critical": ("π΄", "#dc2626", "#fef2f2"),
|
| 408 |
"high": ("π ", "#ea580c", "#fff7ed"),
|
| 409 |
"moderate": ("π‘", "#ca8a04", "#fefce8"),
|
| 410 |
+
"low": ("π’", "#16a34a", "#f0fdf4"),
|
| 411 |
}
|
| 412 |
emoji, color, bg_color = severity_config.get(severity, severity_config["low"])
|
| 413 |
|
|
|
|
| 433 |
alert_items = ""
|
| 434 |
for alert in alerts[:5]:
|
| 435 |
if isinstance(alert, dict):
|
| 436 |
+
alert_items += (
|
| 437 |
+
f"<li><strong>{alert.get('alert_type', 'Alert')}:</strong> {alert.get('message', '')}</li>"
|
| 438 |
+
)
|
| 439 |
else:
|
| 440 |
+
alert_items += f"<li>{alert}</li>"
|
| 441 |
|
| 442 |
parts.append(f"""
|
| 443 |
<div style="background: linear-gradient(135deg, #fef2f2 0%, #fee2e2 100%); border: 1px solid #fecaca; border-radius: 12px; padding: 16px; margin-bottom: 16px;">
|
|
|
|
| 477 |
"high": ("π΄", "#dc2626", "#fef2f2"),
|
| 478 |
"abnormal": ("π‘", "#ca8a04", "#fefce8"),
|
| 479 |
"low": ("π‘", "#ca8a04", "#fefce8"),
|
| 480 |
+
"normal": ("π’", "#16a34a", "#f0fdf4"),
|
| 481 |
}
|
| 482 |
s_emoji, s_color, s_bg = status_styles.get(status, status_styles["normal"])
|
| 483 |
|
|
|
|
| 563 |
parts.append(f"""
|
| 564 |
<div style="background: #f8fafc; border-radius: 12px; padding: 16px; margin-bottom: 16px;">
|
| 565 |
<h4 style="margin: 0 0 12px 0; color: #1e3a5f;">π Understanding Your Results</h4>
|
| 566 |
+
<p style="margin: 0; color: #475569; line-height: 1.6;">{pathophys[:600]}{"..." if len(pathophys) > 600 else ""}</p>
|
| 567 |
</div>
|
| 568 |
""")
|
| 569 |
|
|
|
|
| 673 |
|
| 674 |
Answer:"""
|
| 675 |
response = llm.invoke(prompt)
|
| 676 |
+
return response.content if hasattr(response, "content") else str(response)
|
| 677 |
|
| 678 |
|
| 679 |
+
def answer_medical_question(question: str, context: str = "", chat_history: list | None = None) -> tuple[str, list]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 680 |
"""Answer a medical question using the full agentic RAG pipeline.
|
| 681 |
|
| 682 |
Pipeline: guardrail β retrieve β grade β rewrite β generate.
|
|
|
|
| 829 |
return "Please enter a query."
|
| 830 |
try:
|
| 831 |
from src.services.retrieval.factory import make_retriever
|
| 832 |
+
|
| 833 |
retriever = make_retriever()
|
| 834 |
docs = retriever.retrieve(query, top_k=5)
|
| 835 |
if not docs:
|
|
|
|
| 837 |
parts = []
|
| 838 |
for i, doc in enumerate(docs, 1):
|
| 839 |
title = doc.metadata.get("title", doc.metadata.get("source_file", "Untitled"))
|
| 840 |
+
score = doc.score if hasattr(doc, "score") else 0.0
|
| 841 |
parts.append(f"**[{i}] {title}** (score: {score:.3f})\n{doc.content}\n")
|
| 842 |
return "\n---\n".join(parts)
|
| 843 |
except Exception as exc:
|
|
|
|
| 1106 |
),
|
| 1107 |
css=CUSTOM_CSS,
|
| 1108 |
) as demo:
|
|
|
|
| 1109 |
# ===== HEADER =====
|
| 1110 |
gr.HTML("""
|
| 1111 |
<div class="header-container">
|
|
|
|
| 1139 |
|
| 1140 |
# ===== MAIN TABS =====
|
| 1141 |
with gr.Tabs() as main_tabs:
|
|
|
|
| 1142 |
# ==================== TAB 1: BIOMARKER ANALYSIS ====================
|
| 1143 |
with gr.Tab("π¬ Biomarker Analysis", id="biomarker-tab"):
|
|
|
|
| 1144 |
# ===== MAIN CONTENT =====
|
| 1145 |
with gr.Row(equal_height=False):
|
|
|
|
| 1146 |
# ----- LEFT PANEL: INPUT -----
|
| 1147 |
with gr.Column(scale=2, min_width=400):
|
| 1148 |
gr.HTML('<div class="section-title">π Enter Your Biomarkers</div>')
|
|
|
|
| 1150 |
with gr.Group():
|
| 1151 |
input_text = gr.Textbox(
|
| 1152 |
label="",
|
| 1153 |
+
placeholder='Enter biomarkers in any format:\n\nβ’ Glucose: 140, HbA1c: 7.5, Cholesterol: 210\nβ’ My glucose is 140 and HbA1c is 7.5\nβ’ {"Glucose": 140, "HbA1c": 7.5}',
|
| 1154 |
lines=6,
|
| 1155 |
max_lines=12,
|
| 1156 |
show_label=False,
|
|
|
|
| 1171 |
)
|
| 1172 |
|
| 1173 |
# Status display
|
| 1174 |
+
status_output = gr.Markdown(value="", elem_classes="status-box")
|
|
|
|
|
|
|
|
|
|
| 1175 |
|
| 1176 |
# Quick Examples
|
| 1177 |
gr.HTML('<div class="section-title" style="margin-top: 24px;">β‘ Quick Examples</div>')
|
| 1178 |
+
gr.HTML(
|
| 1179 |
+
'<p style="color: #64748b; font-size: 0.9em; margin-bottom: 12px;">Click any example to load it instantly</p>'
|
| 1180 |
+
)
|
| 1181 |
|
| 1182 |
examples = gr.Examples(
|
| 1183 |
examples=[
|
|
|
|
| 1236 |
<p>Enter your biomarkers on the left and click <strong>Analyze</strong> to get your personalized health insights.</p>
|
| 1237 |
</div>
|
| 1238 |
""",
|
| 1239 |
+
elem_classes="summary-output",
|
| 1240 |
)
|
| 1241 |
|
| 1242 |
with gr.Tab("π Detailed JSON", id="json"):
|
|
|
|
| 1249 |
|
| 1250 |
# ==================== TAB 2: MEDICAL Q&A ====================
|
| 1251 |
with gr.Tab("π¬ Medical Q&A", id="qa-tab"):
|
|
|
|
| 1252 |
gr.HTML("""
|
| 1253 |
<div style="margin-bottom: 20px;">
|
| 1254 |
<h3 style="color: #1e3a5f; margin: 0 0 8px 0;">π¬ Medical Q&A Assistant</h3>
|
|
|
|
| 1269 |
qa_model = gr.Dropdown(
|
| 1270 |
choices=["llama-3.3-70b-versatile", "gemini-2.0-flash", "llama3.1:8b"],
|
| 1271 |
value="llama-3.3-70b-versatile",
|
| 1272 |
+
label="LLM Provider/Model",
|
| 1273 |
)
|
| 1274 |
qa_question = gr.Textbox(
|
| 1275 |
label="Your Question",
|
|
|
|
| 1306 |
|
| 1307 |
with gr.Column(scale=2):
|
| 1308 |
gr.HTML('<h4 style="color: #1e3a5f; margin-bottom: 12px;">π Answer</h4>')
|
| 1309 |
+
qa_answer = gr.Chatbot(label="Medical Q&A History", height=600, elem_classes="qa-output")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1310 |
|
| 1311 |
# Q&A Event Handlers
|
| 1312 |
qa_submit_btn.click(
|
|
|
|
| 1314 |
inputs=[qa_question, qa_context, qa_answer, qa_model],
|
| 1315 |
outputs=qa_answer,
|
| 1316 |
show_progress="minimal",
|
| 1317 |
+
).then(fn=lambda: "", outputs=qa_question)
|
|
|
|
|
|
|
|
|
|
| 1318 |
|
| 1319 |
qa_clear_btn.click(
|
| 1320 |
fn=lambda: ([], ""),
|
|
|
|
| 1325 |
with gr.Tab("π Search Knowledge Base", id="search-tab"):
|
| 1326 |
with gr.Row():
|
| 1327 |
search_input = gr.Textbox(
|
| 1328 |
+
label="Search Query", placeholder="e.g., diabetes management guidelines", lines=2, scale=3
|
|
|
|
|
|
|
|
|
|
| 1329 |
)
|
| 1330 |
search_mode = gr.Radio(
|
| 1331 |
+
choices=["hybrid", "bm25", "vector"], value="hybrid", label="Search Strategy", scale=1
|
|
|
|
|
|
|
|
|
|
| 1332 |
)
|
| 1333 |
search_btn = gr.Button("Search", variant="primary")
|
| 1334 |
search_output = gr.Textbox(label="Results", lines=20, interactive=False)
|
|
|
|
| 1401 |
)
|
| 1402 |
|
| 1403 |
clear_btn.click(
|
| 1404 |
+
fn=lambda: (
|
| 1405 |
+
"",
|
| 1406 |
+
"""
|
| 1407 |
<div style="text-align: center; padding: 60px 20px; color: #94a3b8;">
|
| 1408 |
<div style="font-size: 4em; margin-bottom: 16px;">π¬</div>
|
| 1409 |
<h3 style="color: #64748b; font-weight: 500;">Ready to Analyze</h3>
|
| 1410 |
<p>Enter your biomarkers on the left and click <strong>Analyze</strong> to get your personalized health insights.</p>
|
| 1411 |
</div>
|
| 1412 |
+
""",
|
| 1413 |
+
"",
|
| 1414 |
+
"",
|
| 1415 |
+
),
|
| 1416 |
outputs=[input_text, summary_output, details_output, status_output],
|
| 1417 |
)
|
| 1418 |
|
pytest.ini
CHANGED
|
@@ -5,3 +5,5 @@ filterwarnings =
|
|
| 5 |
|
| 6 |
markers =
|
| 7 |
integration: mark a test as an integration test.
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
markers =
|
| 7 |
integration: mark a test as an integration test.
|
| 8 |
+
|
| 9 |
+
testpaths = tests
|
scripts/chat.py
CHANGED
|
@@ -26,15 +26,16 @@ from pathlib import Path
|
|
| 26 |
from typing import Any
|
| 27 |
|
| 28 |
# Set UTF-8 encoding for Windows console
|
| 29 |
-
if sys.platform ==
|
| 30 |
try:
|
| 31 |
-
sys.stdout.reconfigure(encoding=
|
| 32 |
-
sys.stderr.reconfigure(encoding=
|
| 33 |
except Exception:
|
| 34 |
import codecs
|
| 35 |
-
|
| 36 |
-
sys.
|
| 37 |
-
|
|
|
|
| 38 |
|
| 39 |
# Add parent directory to path for imports
|
| 40 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
@@ -82,6 +83,7 @@ If you cannot find any biomarkers, return {{"biomarkers": {{}}, "patient_context
|
|
| 82 |
# Component 1: Biomarker Extraction
|
| 83 |
# ============================================================================
|
| 84 |
|
|
|
|
| 85 |
def _parse_llm_json(content: str) -> dict[str, Any]:
|
| 86 |
"""Parse JSON payload from LLM output with fallback recovery."""
|
| 87 |
text = content.strip()
|
|
@@ -97,14 +99,14 @@ def _parse_llm_json(content: str) -> dict[str, Any]:
|
|
| 97 |
left = text.find("{")
|
| 98 |
right = text.rfind("}")
|
| 99 |
if left != -1 and right != -1 and right > left:
|
| 100 |
-
return json.loads(text[left:right + 1])
|
| 101 |
raise
|
| 102 |
|
| 103 |
|
| 104 |
def extract_biomarkers(user_message: str) -> tuple[dict[str, float], dict[str, Any]]:
|
| 105 |
"""
|
| 106 |
Extract biomarker values from natural language using LLM.
|
| 107 |
-
|
| 108 |
Returns:
|
| 109 |
Tuple of (biomarkers_dict, patient_context_dict)
|
| 110 |
"""
|
|
@@ -140,6 +142,7 @@ def extract_biomarkers(user_message: str) -> tuple[dict[str, float], dict[str, A
|
|
| 140 |
except Exception as e:
|
| 141 |
print(f"β οΈ Extraction failed: {e}")
|
| 142 |
import traceback
|
|
|
|
| 143 |
traceback.print_exc()
|
| 144 |
return {}, {}
|
| 145 |
|
|
@@ -148,17 +151,12 @@ def extract_biomarkers(user_message: str) -> tuple[dict[str, float], dict[str, A
|
|
| 148 |
# Component 2: Disease Prediction
|
| 149 |
# ============================================================================
|
| 150 |
|
|
|
|
| 151 |
def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 152 |
"""
|
| 153 |
Simple rule-based disease prediction based on key biomarkers.
|
| 154 |
"""
|
| 155 |
-
scores = {
|
| 156 |
-
"Diabetes": 0.0,
|
| 157 |
-
"Anemia": 0.0,
|
| 158 |
-
"Heart Disease": 0.0,
|
| 159 |
-
"Thrombocytopenia": 0.0,
|
| 160 |
-
"Thalassemia": 0.0
|
| 161 |
-
}
|
| 162 |
|
| 163 |
# Helper: check both abbreviated and normalized biomarker names
|
| 164 |
# Returns None when biomarker is not present (avoids false triggers)
|
|
@@ -228,11 +226,7 @@ def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
|
|
| 228 |
else:
|
| 229 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 230 |
|
| 231 |
-
return {
|
| 232 |
-
"disease": top_disease,
|
| 233 |
-
"confidence": confidence,
|
| 234 |
-
"probabilities": probabilities
|
| 235 |
-
}
|
| 236 |
|
| 237 |
|
| 238 |
def predict_disease_llm(biomarkers: dict[str, float], patient_context: dict) -> dict[str, Any]:
|
|
@@ -280,6 +274,7 @@ Return ONLY valid JSON (no other text):
|
|
| 280 |
except Exception as e:
|
| 281 |
print(f"β οΈ LLM prediction failed ({e}), using rule-based fallback")
|
| 282 |
import traceback
|
|
|
|
| 283 |
traceback.print_exc()
|
| 284 |
return predict_disease_simple(biomarkers)
|
| 285 |
|
|
@@ -288,6 +283,7 @@ Return ONLY valid JSON (no other text):
|
|
| 288 |
# Component 3: Conversational Formatter
|
| 289 |
# ============================================================================
|
| 290 |
|
|
|
|
| 291 |
def _coerce_to_dict(obj) -> dict:
|
| 292 |
"""Convert a Pydantic model or arbitrary object to a plain dict."""
|
| 293 |
if isinstance(obj, dict):
|
|
@@ -379,6 +375,7 @@ def format_conversational(result: dict[str, Any], user_name: str = "there") -> s
|
|
| 379 |
# Component 4: Helper Functions
|
| 380 |
# ============================================================================
|
| 381 |
|
|
|
|
| 382 |
def print_biomarker_help():
|
| 383 |
"""Print list of supported biomarkers"""
|
| 384 |
print("\nπ Supported Biomarkers (24 total):")
|
|
@@ -409,7 +406,7 @@ def run_example_case(guild):
|
|
| 409 |
"Platelets": 220000,
|
| 410 |
"White Blood Cells": 7500,
|
| 411 |
"Systolic Blood Pressure": 145,
|
| 412 |
-
"Diastolic Blood Pressure": 92
|
| 413 |
}
|
| 414 |
|
| 415 |
prediction = {
|
|
@@ -420,25 +417,25 @@ def run_example_case(guild):
|
|
| 420 |
"Heart Disease": 0.08,
|
| 421 |
"Anemia": 0.03,
|
| 422 |
"Thrombocytopenia": 0.01,
|
| 423 |
-
"Thalassemia": 0.01
|
| 424 |
-
}
|
| 425 |
}
|
| 426 |
|
| 427 |
patient_input = PatientInput(
|
| 428 |
biomarkers=example_biomarkers,
|
| 429 |
model_prediction=prediction,
|
| 430 |
-
patient_context={"age": 52, "gender": "male", "bmi": 31.2}
|
| 431 |
)
|
| 432 |
|
| 433 |
print("π Running analysis...\n")
|
| 434 |
result = guild.run(patient_input)
|
| 435 |
|
| 436 |
response = format_conversational(result.get("final_response", result), "there")
|
| 437 |
-
print("\n" + "="*70)
|
| 438 |
print("π€ RAG-BOT:")
|
| 439 |
-
print("="*70)
|
| 440 |
print(response)
|
| 441 |
-
print("="*70 + "\n")
|
| 442 |
|
| 443 |
|
| 444 |
def save_report(result: dict, biomarkers: dict):
|
|
@@ -447,11 +444,10 @@ def save_report(result: dict, biomarkers: dict):
|
|
| 447 |
|
| 448 |
# final_response is already a plain dict built by the synthesizer
|
| 449 |
final = result.get("final_response") or {}
|
| 450 |
-
disease = (
|
| 451 |
-
|
| 452 |
-
or result.get("model_prediction", {}).get("disease", "unknown")
|
| 453 |
)
|
| 454 |
-
disease_safe = disease.replace(
|
| 455 |
filename = f"report_{disease_safe}_{timestamp}.json"
|
| 456 |
|
| 457 |
output_dir = Path("data/chat_reports")
|
|
@@ -465,9 +461,9 @@ def save_report(result: dict, biomarkers: dict):
|
|
| 465 |
return {k: _to_dict(v) for k, v in obj.items()}
|
| 466 |
if isinstance(obj, list):
|
| 467 |
return [_to_dict(i) for i in obj]
|
| 468 |
-
if hasattr(obj, "model_dump"):
|
| 469 |
return _to_dict(obj.model_dump())
|
| 470 |
-
if hasattr(obj, "dict"):
|
| 471 |
return _to_dict(obj.dict())
|
| 472 |
# Scalars and other primitives are returned as-is
|
| 473 |
return obj
|
|
@@ -480,7 +476,7 @@ def save_report(result: dict, biomarkers: dict):
|
|
| 480 |
"safety_alerts": _to_dict(result.get("safety_alerts", [])),
|
| 481 |
}
|
| 482 |
|
| 483 |
-
with open(filepath,
|
| 484 |
json.dump(report, f, indent=2)
|
| 485 |
|
| 486 |
print(f"β
Report saved to: {filepath}\n")
|
|
@@ -490,21 +486,22 @@ def save_report(result: dict, biomarkers: dict):
|
|
| 490 |
# Main Chat Interface
|
| 491 |
# ============================================================================
|
| 492 |
|
|
|
|
| 493 |
def chat_interface():
|
| 494 |
"""
|
| 495 |
Main interactive CLI chatbot for MediGuard AI RAG-Helper.
|
| 496 |
"""
|
| 497 |
# Print welcome banner
|
| 498 |
-
print("\n" + "="*70)
|
| 499 |
print("π€ MediGuard AI RAG-Helper - Interactive Chat")
|
| 500 |
-
print("="*70)
|
| 501 |
print("\nWelcome! I can help you understand your blood test results.\n")
|
| 502 |
print("You can:")
|
| 503 |
print(" 1. Describe your biomarkers (e.g., 'My glucose is 140, HbA1c is 7.5')")
|
| 504 |
print(" 2. Type 'example' to see a sample diabetes case")
|
| 505 |
print(" 3. Type 'help' for biomarker list")
|
| 506 |
print(" 4. Type 'quit' to exit\n")
|
| 507 |
-
print("="*70 + "\n")
|
| 508 |
|
| 509 |
# Initialize guild (one-time setup)
|
| 510 |
print("π§ Initializing medical knowledge system...")
|
|
@@ -532,15 +529,15 @@ def chat_interface():
|
|
| 532 |
continue
|
| 533 |
|
| 534 |
# Handle special commands
|
| 535 |
-
if user_input.lower() in [
|
| 536 |
print("\nπ Thank you for using MediGuard AI. Stay healthy!")
|
| 537 |
break
|
| 538 |
|
| 539 |
-
if user_input.lower() ==
|
| 540 |
print_biomarker_help()
|
| 541 |
continue
|
| 542 |
|
| 543 |
-
if user_input.lower() ==
|
| 544 |
run_example_case(guild)
|
| 545 |
continue
|
| 546 |
|
|
@@ -571,7 +568,7 @@ def chat_interface():
|
|
| 571 |
patient_input = PatientInput(
|
| 572 |
biomarkers=biomarkers,
|
| 573 |
model_prediction=prediction,
|
| 574 |
-
patient_context=patient_context if patient_context else {"source": "chat"}
|
| 575 |
)
|
| 576 |
|
| 577 |
# Run full RAG workflow
|
|
@@ -584,23 +581,20 @@ def chat_interface():
|
|
| 584 |
response = format_conversational(result.get("final_response", result), user_name)
|
| 585 |
|
| 586 |
# Display response
|
| 587 |
-
print("\n" + "="*70)
|
| 588 |
print("π€ RAG-BOT:")
|
| 589 |
-
print("="*70)
|
| 590 |
print(response)
|
| 591 |
-
print("="*70 + "\n")
|
| 592 |
|
| 593 |
# Save to history
|
| 594 |
-
conversation_history.append(
|
| 595 |
-
"user_input": user_input,
|
| 596 |
-
|
| 597 |
-
"prediction": prediction,
|
| 598 |
-
"result": result
|
| 599 |
-
})
|
| 600 |
|
| 601 |
# Ask if user wants to save report
|
| 602 |
save_choice = input("πΎ Save detailed report to file? (y/n): ").strip().lower()
|
| 603 |
-
if save_choice ==
|
| 604 |
save_report(result, biomarkers)
|
| 605 |
|
| 606 |
print("\nYou can:")
|
|
@@ -612,6 +606,7 @@ def chat_interface():
|
|
| 612 |
break
|
| 613 |
except Exception as e:
|
| 614 |
import traceback
|
|
|
|
| 615 |
traceback.print_exc()
|
| 616 |
print(f"\nβ Analysis failed: {e}")
|
| 617 |
print("\nThis might be due to:")
|
|
|
|
| 26 |
from typing import Any
|
| 27 |
|
| 28 |
# Set UTF-8 encoding for Windows console
|
| 29 |
+
if sys.platform == "win32":
|
| 30 |
try:
|
| 31 |
+
sys.stdout.reconfigure(encoding="utf-8")
|
| 32 |
+
sys.stderr.reconfigure(encoding="utf-8")
|
| 33 |
except Exception:
|
| 34 |
import codecs
|
| 35 |
+
|
| 36 |
+
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer, "strict")
|
| 37 |
+
sys.stderr = codecs.getwriter("utf-8")(sys.stderr.buffer, "strict")
|
| 38 |
+
os.system("chcp 65001 > nul 2>&1")
|
| 39 |
|
| 40 |
# Add parent directory to path for imports
|
| 41 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
| 83 |
# Component 1: Biomarker Extraction
|
| 84 |
# ============================================================================
|
| 85 |
|
| 86 |
+
|
| 87 |
def _parse_llm_json(content: str) -> dict[str, Any]:
|
| 88 |
"""Parse JSON payload from LLM output with fallback recovery."""
|
| 89 |
text = content.strip()
|
|
|
|
| 99 |
left = text.find("{")
|
| 100 |
right = text.rfind("}")
|
| 101 |
if left != -1 and right != -1 and right > left:
|
| 102 |
+
return json.loads(text[left : right + 1])
|
| 103 |
raise
|
| 104 |
|
| 105 |
|
| 106 |
def extract_biomarkers(user_message: str) -> tuple[dict[str, float], dict[str, Any]]:
|
| 107 |
"""
|
| 108 |
Extract biomarker values from natural language using LLM.
|
| 109 |
+
|
| 110 |
Returns:
|
| 111 |
Tuple of (biomarkers_dict, patient_context_dict)
|
| 112 |
"""
|
|
|
|
| 142 |
except Exception as e:
|
| 143 |
print(f"β οΈ Extraction failed: {e}")
|
| 144 |
import traceback
|
| 145 |
+
|
| 146 |
traceback.print_exc()
|
| 147 |
return {}, {}
|
| 148 |
|
|
|
|
| 151 |
# Component 2: Disease Prediction
|
| 152 |
# ============================================================================
|
| 153 |
|
| 154 |
+
|
| 155 |
def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 156 |
"""
|
| 157 |
Simple rule-based disease prediction based on key biomarkers.
|
| 158 |
"""
|
| 159 |
+
scores = {"Diabetes": 0.0, "Anemia": 0.0, "Heart Disease": 0.0, "Thrombocytopenia": 0.0, "Thalassemia": 0.0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
# Helper: check both abbreviated and normalized biomarker names
|
| 162 |
# Returns None when biomarker is not present (avoids false triggers)
|
|
|
|
| 226 |
else:
|
| 227 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 228 |
|
| 229 |
+
return {"disease": top_disease, "confidence": confidence, "probabilities": probabilities}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
|
| 232 |
def predict_disease_llm(biomarkers: dict[str, float], patient_context: dict) -> dict[str, Any]:
|
|
|
|
| 274 |
except Exception as e:
|
| 275 |
print(f"β οΈ LLM prediction failed ({e}), using rule-based fallback")
|
| 276 |
import traceback
|
| 277 |
+
|
| 278 |
traceback.print_exc()
|
| 279 |
return predict_disease_simple(biomarkers)
|
| 280 |
|
|
|
|
| 283 |
# Component 3: Conversational Formatter
|
| 284 |
# ============================================================================
|
| 285 |
|
| 286 |
+
|
| 287 |
def _coerce_to_dict(obj) -> dict:
|
| 288 |
"""Convert a Pydantic model or arbitrary object to a plain dict."""
|
| 289 |
if isinstance(obj, dict):
|
|
|
|
| 375 |
# Component 4: Helper Functions
|
| 376 |
# ============================================================================
|
| 377 |
|
| 378 |
+
|
| 379 |
def print_biomarker_help():
|
| 380 |
"""Print list of supported biomarkers"""
|
| 381 |
print("\nπ Supported Biomarkers (24 total):")
|
|
|
|
| 406 |
"Platelets": 220000,
|
| 407 |
"White Blood Cells": 7500,
|
| 408 |
"Systolic Blood Pressure": 145,
|
| 409 |
+
"Diastolic Blood Pressure": 92,
|
| 410 |
}
|
| 411 |
|
| 412 |
prediction = {
|
|
|
|
| 417 |
"Heart Disease": 0.08,
|
| 418 |
"Anemia": 0.03,
|
| 419 |
"Thrombocytopenia": 0.01,
|
| 420 |
+
"Thalassemia": 0.01,
|
| 421 |
+
},
|
| 422 |
}
|
| 423 |
|
| 424 |
patient_input = PatientInput(
|
| 425 |
biomarkers=example_biomarkers,
|
| 426 |
model_prediction=prediction,
|
| 427 |
+
patient_context={"age": 52, "gender": "male", "bmi": 31.2},
|
| 428 |
)
|
| 429 |
|
| 430 |
print("π Running analysis...\n")
|
| 431 |
result = guild.run(patient_input)
|
| 432 |
|
| 433 |
response = format_conversational(result.get("final_response", result), "there")
|
| 434 |
+
print("\n" + "=" * 70)
|
| 435 |
print("π€ RAG-BOT:")
|
| 436 |
+
print("=" * 70)
|
| 437 |
print(response)
|
| 438 |
+
print("=" * 70 + "\n")
|
| 439 |
|
| 440 |
|
| 441 |
def save_report(result: dict, biomarkers: dict):
|
|
|
|
| 444 |
|
| 445 |
# final_response is already a plain dict built by the synthesizer
|
| 446 |
final = result.get("final_response") or {}
|
| 447 |
+
disease = final.get("prediction_explanation", {}).get("primary_disease") or result.get("model_prediction", {}).get(
|
| 448 |
+
"disease", "unknown"
|
|
|
|
| 449 |
)
|
| 450 |
+
disease_safe = disease.replace(" ", "_").replace("/", "_")
|
| 451 |
filename = f"report_{disease_safe}_{timestamp}.json"
|
| 452 |
|
| 453 |
output_dir = Path("data/chat_reports")
|
|
|
|
| 461 |
return {k: _to_dict(v) for k, v in obj.items()}
|
| 462 |
if isinstance(obj, list):
|
| 463 |
return [_to_dict(i) for i in obj]
|
| 464 |
+
if hasattr(obj, "model_dump"): # Pydantic v2
|
| 465 |
return _to_dict(obj.model_dump())
|
| 466 |
+
if hasattr(obj, "dict"): # Pydantic v1
|
| 467 |
return _to_dict(obj.dict())
|
| 468 |
# Scalars and other primitives are returned as-is
|
| 469 |
return obj
|
|
|
|
| 476 |
"safety_alerts": _to_dict(result.get("safety_alerts", [])),
|
| 477 |
}
|
| 478 |
|
| 479 |
+
with open(filepath, "w") as f:
|
| 480 |
json.dump(report, f, indent=2)
|
| 481 |
|
| 482 |
print(f"β
Report saved to: {filepath}\n")
|
|
|
|
| 486 |
# Main Chat Interface
|
| 487 |
# ============================================================================
|
| 488 |
|
| 489 |
+
|
| 490 |
def chat_interface():
|
| 491 |
"""
|
| 492 |
Main interactive CLI chatbot for MediGuard AI RAG-Helper.
|
| 493 |
"""
|
| 494 |
# Print welcome banner
|
| 495 |
+
print("\n" + "=" * 70)
|
| 496 |
print("π€ MediGuard AI RAG-Helper - Interactive Chat")
|
| 497 |
+
print("=" * 70)
|
| 498 |
print("\nWelcome! I can help you understand your blood test results.\n")
|
| 499 |
print("You can:")
|
| 500 |
print(" 1. Describe your biomarkers (e.g., 'My glucose is 140, HbA1c is 7.5')")
|
| 501 |
print(" 2. Type 'example' to see a sample diabetes case")
|
| 502 |
print(" 3. Type 'help' for biomarker list")
|
| 503 |
print(" 4. Type 'quit' to exit\n")
|
| 504 |
+
print("=" * 70 + "\n")
|
| 505 |
|
| 506 |
# Initialize guild (one-time setup)
|
| 507 |
print("π§ Initializing medical knowledge system...")
|
|
|
|
| 529 |
continue
|
| 530 |
|
| 531 |
# Handle special commands
|
| 532 |
+
if user_input.lower() in ["quit", "exit", "q"]:
|
| 533 |
print("\nπ Thank you for using MediGuard AI. Stay healthy!")
|
| 534 |
break
|
| 535 |
|
| 536 |
+
if user_input.lower() == "help":
|
| 537 |
print_biomarker_help()
|
| 538 |
continue
|
| 539 |
|
| 540 |
+
if user_input.lower() == "example":
|
| 541 |
run_example_case(guild)
|
| 542 |
continue
|
| 543 |
|
|
|
|
| 568 |
patient_input = PatientInput(
|
| 569 |
biomarkers=biomarkers,
|
| 570 |
model_prediction=prediction,
|
| 571 |
+
patient_context=patient_context if patient_context else {"source": "chat"},
|
| 572 |
)
|
| 573 |
|
| 574 |
# Run full RAG workflow
|
|
|
|
| 581 |
response = format_conversational(result.get("final_response", result), user_name)
|
| 582 |
|
| 583 |
# Display response
|
| 584 |
+
print("\n" + "=" * 70)
|
| 585 |
print("π€ RAG-BOT:")
|
| 586 |
+
print("=" * 70)
|
| 587 |
print(response)
|
| 588 |
+
print("=" * 70 + "\n")
|
| 589 |
|
| 590 |
# Save to history
|
| 591 |
+
conversation_history.append(
|
| 592 |
+
{"user_input": user_input, "biomarkers": biomarkers, "prediction": prediction, "result": result}
|
| 593 |
+
)
|
|
|
|
|
|
|
|
|
|
| 594 |
|
| 595 |
# Ask if user wants to save report
|
| 596 |
save_choice = input("πΎ Save detailed report to file? (y/n): ").strip().lower()
|
| 597 |
+
if save_choice == "y":
|
| 598 |
save_report(result, biomarkers)
|
| 599 |
|
| 600 |
print("\nYou can:")
|
|
|
|
| 606 |
break
|
| 607 |
except Exception as e:
|
| 608 |
import traceback
|
| 609 |
+
|
| 610 |
traceback.print_exc()
|
| 611 |
print(f"\nβ Analysis failed: {e}")
|
| 612 |
print("\nThis might be due to:")
|
scripts/monitor_test.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
"""Monitor evolution test progress"""
|
|
|
|
| 2 |
import time
|
| 3 |
|
| 4 |
print("Monitoring evolution test... (Press Ctrl+C to stop)")
|
|
@@ -6,7 +7,7 @@ print("=" * 70)
|
|
| 6 |
|
| 7 |
for i in range(60): # Check for 5 minutes
|
| 8 |
time.sleep(5)
|
| 9 |
-
print(f"[{i*5}s] Test still running...")
|
| 10 |
|
| 11 |
print("\nTest should be complete or nearly complete.")
|
| 12 |
print("Check terminal output for results.")
|
|
|
|
| 1 |
"""Monitor evolution test progress"""
|
| 2 |
+
|
| 3 |
import time
|
| 4 |
|
| 5 |
print("Monitoring evolution test... (Press Ctrl+C to stop)")
|
|
|
|
| 7 |
|
| 8 |
for i in range(60): # Check for 5 minutes
|
| 9 |
time.sleep(5)
|
| 10 |
+
print(f"[{i * 5}s] Test still running...")
|
| 11 |
|
| 12 |
print("\nTest should be complete or nearly complete.")
|
| 13 |
print("Check terminal output for results.")
|
scripts/setup_embeddings.py
CHANGED
|
@@ -8,9 +8,9 @@ from pathlib import Path
|
|
| 8 |
def setup_google_api_key():
|
| 9 |
"""Interactive setup for Google API key"""
|
| 10 |
|
| 11 |
-
print("="*70)
|
| 12 |
print("Fast Embeddings Setup - Google Gemini API")
|
| 13 |
-
print("="*70)
|
| 14 |
|
| 15 |
print("\nWhy Google Gemini?")
|
| 16 |
print(" - 100x faster than local Ollama (2 mins vs 30+ mins)")
|
|
@@ -18,9 +18,9 @@ def setup_google_api_key():
|
|
| 18 |
print(" - High quality embeddings")
|
| 19 |
print(" - Automatic fallback to Ollama if unavailable")
|
| 20 |
|
| 21 |
-
print("\n" + "="*70)
|
| 22 |
print("Step 1: Get Your Free API Key")
|
| 23 |
-
print("="*70)
|
| 24 |
print("\n1. Open this URL in your browser:")
|
| 25 |
print(" https://aistudio.google.com/app/apikey")
|
| 26 |
print("\n2. Sign in with Google account")
|
|
@@ -38,7 +38,7 @@ def setup_google_api_key():
|
|
| 38 |
if not api_key.startswith("AIza"):
|
| 39 |
print("\nWarning: Key doesn't start with 'AIza'. Are you sure this is correct?")
|
| 40 |
confirm = input("Continue anyway? (y/n): ").strip().lower()
|
| 41 |
-
if confirm !=
|
| 42 |
return False
|
| 43 |
|
| 44 |
# Update .env file
|
|
@@ -52,28 +52,28 @@ def setup_google_api_key():
|
|
| 52 |
updated = False
|
| 53 |
for i, line in enumerate(lines):
|
| 54 |
if line.startswith("GOOGLE_API_KEY="):
|
| 55 |
-
lines[i] = f
|
| 56 |
updated = True
|
| 57 |
break
|
| 58 |
|
| 59 |
if not updated:
|
| 60 |
-
lines.insert(0, f
|
| 61 |
|
| 62 |
-
with open(env_path,
|
| 63 |
f.writelines(lines)
|
| 64 |
else:
|
| 65 |
# Create new .env file
|
| 66 |
-
with open(env_path,
|
| 67 |
-
f.write(f
|
| 68 |
|
| 69 |
print("\nAPI key saved to .env file!")
|
| 70 |
-
print("\n" + "="*70)
|
| 71 |
print("Step 2: Build Vector Store")
|
| 72 |
-
print("="*70)
|
| 73 |
print("\nRun this command:")
|
| 74 |
print(" python src/pdf_processor.py")
|
| 75 |
print("\nChoose option 1 (Google Gemini) when prompted.")
|
| 76 |
-
print("\n" + "="*70)
|
| 77 |
|
| 78 |
return True
|
| 79 |
|
|
|
|
| 8 |
def setup_google_api_key():
|
| 9 |
"""Interactive setup for Google API key"""
|
| 10 |
|
| 11 |
+
print("=" * 70)
|
| 12 |
print("Fast Embeddings Setup - Google Gemini API")
|
| 13 |
+
print("=" * 70)
|
| 14 |
|
| 15 |
print("\nWhy Google Gemini?")
|
| 16 |
print(" - 100x faster than local Ollama (2 mins vs 30+ mins)")
|
|
|
|
| 18 |
print(" - High quality embeddings")
|
| 19 |
print(" - Automatic fallback to Ollama if unavailable")
|
| 20 |
|
| 21 |
+
print("\n" + "=" * 70)
|
| 22 |
print("Step 1: Get Your Free API Key")
|
| 23 |
+
print("=" * 70)
|
| 24 |
print("\n1. Open this URL in your browser:")
|
| 25 |
print(" https://aistudio.google.com/app/apikey")
|
| 26 |
print("\n2. Sign in with Google account")
|
|
|
|
| 38 |
if not api_key.startswith("AIza"):
|
| 39 |
print("\nWarning: Key doesn't start with 'AIza'. Are you sure this is correct?")
|
| 40 |
confirm = input("Continue anyway? (y/n): ").strip().lower()
|
| 41 |
+
if confirm != "y":
|
| 42 |
return False
|
| 43 |
|
| 44 |
# Update .env file
|
|
|
|
| 52 |
updated = False
|
| 53 |
for i, line in enumerate(lines):
|
| 54 |
if line.startswith("GOOGLE_API_KEY="):
|
| 55 |
+
lines[i] = f"GOOGLE_API_KEY={api_key}\n"
|
| 56 |
updated = True
|
| 57 |
break
|
| 58 |
|
| 59 |
if not updated:
|
| 60 |
+
lines.insert(0, f"GOOGLE_API_KEY={api_key}\n")
|
| 61 |
|
| 62 |
+
with open(env_path, "w") as f:
|
| 63 |
f.writelines(lines)
|
| 64 |
else:
|
| 65 |
# Create new .env file
|
| 66 |
+
with open(env_path, "w") as f:
|
| 67 |
+
f.write(f"GOOGLE_API_KEY={api_key}\n")
|
| 68 |
|
| 69 |
print("\nAPI key saved to .env file!")
|
| 70 |
+
print("\n" + "=" * 70)
|
| 71 |
print("Step 2: Build Vector Store")
|
| 72 |
+
print("=" * 70)
|
| 73 |
print("\nRun this command:")
|
| 74 |
print(" python src/pdf_processor.py")
|
| 75 |
print("\nChoose option 1 (Google Gemini) when prompted.")
|
| 76 |
+
print("\n" + "=" * 70)
|
| 77 |
|
| 78 |
return True
|
| 79 |
|
scripts/test_chat_demo.py
CHANGED
|
@@ -10,16 +10,16 @@ test_cases = [
|
|
| 10 |
"help", # Show biomarker help
|
| 11 |
"glucose 185, HbA1c 8.2, cholesterol 235, triglycerides 210, HDL 38", # Diabetes case
|
| 12 |
"n", # Don't save report
|
| 13 |
-
"quit" # Exit
|
| 14 |
]
|
| 15 |
|
| 16 |
-
print("="*70)
|
| 17 |
print("CLI Chatbot Demo Test")
|
| 18 |
-
print("="*70)
|
| 19 |
print("\nThis will run the chatbot with pre-defined inputs:")
|
| 20 |
for i, case in enumerate(test_cases, 1):
|
| 21 |
print(f" {i}. {case}")
|
| 22 |
-
print("\n" + "="*70 + "\n")
|
| 23 |
|
| 24 |
# Prepare input string
|
| 25 |
input_str = "\n".join(test_cases) + "\n"
|
|
@@ -32,8 +32,8 @@ try:
|
|
| 32 |
capture_output=True,
|
| 33 |
text=True,
|
| 34 |
timeout=120,
|
| 35 |
-
encoding=
|
| 36 |
-
errors=
|
| 37 |
)
|
| 38 |
|
| 39 |
print("STDOUT:")
|
|
|
|
| 10 |
"help", # Show biomarker help
|
| 11 |
"glucose 185, HbA1c 8.2, cholesterol 235, triglycerides 210, HDL 38", # Diabetes case
|
| 12 |
"n", # Don't save report
|
| 13 |
+
"quit", # Exit
|
| 14 |
]
|
| 15 |
|
| 16 |
+
print("=" * 70)
|
| 17 |
print("CLI Chatbot Demo Test")
|
| 18 |
+
print("=" * 70)
|
| 19 |
print("\nThis will run the chatbot with pre-defined inputs:")
|
| 20 |
for i, case in enumerate(test_cases, 1):
|
| 21 |
print(f" {i}. {case}")
|
| 22 |
+
print("\n" + "=" * 70 + "\n")
|
| 23 |
|
| 24 |
# Prepare input string
|
| 25 |
input_str = "\n".join(test_cases) + "\n"
|
|
|
|
| 32 |
capture_output=True,
|
| 33 |
text=True,
|
| 34 |
timeout=120,
|
| 35 |
+
encoding="utf-8",
|
| 36 |
+
errors="replace",
|
| 37 |
)
|
| 38 |
|
| 39 |
print("STDOUT:")
|
scripts/test_extraction.py
CHANGED
|
@@ -16,13 +16,13 @@ test_inputs = [
|
|
| 16 |
"glucose=185, HbA1c=8.2, cholesterol=235, triglycerides=210, HDL=38",
|
| 17 |
]
|
| 18 |
|
| 19 |
-
print("="*70)
|
| 20 |
print("BIOMARKER EXTRACTION TEST")
|
| 21 |
-
print("="*70)
|
| 22 |
|
| 23 |
for i, test_input in enumerate(test_inputs, 1):
|
| 24 |
print(f"\n[Test {i}] Input: '{test_input}'")
|
| 25 |
-
print("-"*70)
|
| 26 |
|
| 27 |
biomarkers, context = extract_biomarkers(test_input)
|
| 28 |
|
|
@@ -44,6 +44,6 @@ for i, test_input in enumerate(test_inputs, 1):
|
|
| 44 |
|
| 45 |
print()
|
| 46 |
|
| 47 |
-
print("="*70)
|
| 48 |
print("TEST COMPLETE")
|
| 49 |
-
print("="*70)
|
|
|
|
| 16 |
"glucose=185, HbA1c=8.2, cholesterol=235, triglycerides=210, HDL=38",
|
| 17 |
]
|
| 18 |
|
| 19 |
+
print("=" * 70)
|
| 20 |
print("BIOMARKER EXTRACTION TEST")
|
| 21 |
+
print("=" * 70)
|
| 22 |
|
| 23 |
for i, test_input in enumerate(test_inputs, 1):
|
| 24 |
print(f"\n[Test {i}] Input: '{test_input}'")
|
| 25 |
+
print("-" * 70)
|
| 26 |
|
| 27 |
biomarkers, context = extract_biomarkers(test_input)
|
| 28 |
|
|
|
|
| 44 |
|
| 45 |
print()
|
| 46 |
|
| 47 |
+
print("=" * 70)
|
| 48 |
print("TEST COMPLETE")
|
| 49 |
+
print("=" * 70)
|
src/agents/biomarker_analyzer.py
CHANGED
|
@@ -3,7 +3,6 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Biomarker Analyzer Agent - Validates biomarker values and flags anomalies
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
|
| 7 |
from src.biomarker_validator import BiomarkerValidator
|
| 8 |
from src.llm_config import llm_config
|
| 9 |
from src.state import AgentOutput, BiomarkerFlag, GuildState
|
|
@@ -19,28 +18,26 @@ class BiomarkerAnalyzerAgent:
|
|
| 19 |
def analyze(self, state: GuildState) -> GuildState:
|
| 20 |
"""
|
| 21 |
Main agent function to analyze biomarkers.
|
| 22 |
-
|
| 23 |
Args:
|
| 24 |
state: Current guild state with patient input
|
| 25 |
-
|
| 26 |
Returns:
|
| 27 |
Updated state with biomarker analysis
|
| 28 |
"""
|
| 29 |
-
print("\n" + "="*70)
|
| 30 |
print("EXECUTING: Biomarker Analyzer Agent")
|
| 31 |
-
print("="*70)
|
| 32 |
|
| 33 |
-
biomarkers = state[
|
| 34 |
-
patient_context = state.get(
|
| 35 |
-
gender = patient_context.get(
|
| 36 |
-
predicted_disease = state[
|
| 37 |
|
| 38 |
# Validate all biomarkers
|
| 39 |
print(f"\nValidating {len(biomarkers)} biomarkers...")
|
| 40 |
flags, alerts = self.validator.validate_all(
|
| 41 |
-
biomarkers=biomarkers,
|
| 42 |
-
gender=gender,
|
| 43 |
-
threshold_pct=state['sop'].biomarker_analyzer_threshold
|
| 44 |
)
|
| 45 |
|
| 46 |
# Get disease-relevant biomarkers
|
|
@@ -54,14 +51,11 @@ class BiomarkerAnalyzerAgent:
|
|
| 54 |
"safety_alerts": [alert.model_dump() for alert in alerts],
|
| 55 |
"relevant_biomarkers": relevant_biomarkers,
|
| 56 |
"summary": summary,
|
| 57 |
-
"validation_complete": True
|
| 58 |
}
|
| 59 |
|
| 60 |
# Create agent output
|
| 61 |
-
output = AgentOutput(
|
| 62 |
-
agent_name="Biomarker Analyzer",
|
| 63 |
-
findings=findings
|
| 64 |
-
)
|
| 65 |
|
| 66 |
# Update state
|
| 67 |
print("\nAnalysis complete:")
|
|
@@ -71,10 +65,10 @@ class BiomarkerAnalyzerAgent:
|
|
| 71 |
print(f" - {len(relevant_biomarkers)} disease-relevant biomarkers identified")
|
| 72 |
|
| 73 |
return {
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
}
|
| 79 |
|
| 80 |
def _generate_summary(
|
|
@@ -83,13 +77,13 @@ class BiomarkerAnalyzerAgent:
|
|
| 83 |
flags: list[BiomarkerFlag],
|
| 84 |
alerts: list,
|
| 85 |
relevant_biomarkers: list[str],
|
| 86 |
-
disease: str
|
| 87 |
) -> str:
|
| 88 |
"""Generate a concise summary of biomarker findings"""
|
| 89 |
|
| 90 |
# Count anomalies
|
| 91 |
-
critical = [f for f in flags if
|
| 92 |
-
high_low = [f for f in flags if f.status in [
|
| 93 |
|
| 94 |
prompt = f"""You are a medical data analyst. Provide a brief, clinical summary of these biomarker results.
|
| 95 |
|
|
|
|
| 3 |
Biomarker Analyzer Agent - Validates biomarker values and flags anomalies
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 6 |
from src.biomarker_validator import BiomarkerValidator
|
| 7 |
from src.llm_config import llm_config
|
| 8 |
from src.state import AgentOutput, BiomarkerFlag, GuildState
|
|
|
|
| 18 |
def analyze(self, state: GuildState) -> GuildState:
|
| 19 |
"""
|
| 20 |
Main agent function to analyze biomarkers.
|
| 21 |
+
|
| 22 |
Args:
|
| 23 |
state: Current guild state with patient input
|
| 24 |
+
|
| 25 |
Returns:
|
| 26 |
Updated state with biomarker analysis
|
| 27 |
"""
|
| 28 |
+
print("\n" + "=" * 70)
|
| 29 |
print("EXECUTING: Biomarker Analyzer Agent")
|
| 30 |
+
print("=" * 70)
|
| 31 |
|
| 32 |
+
biomarkers = state["patient_biomarkers"]
|
| 33 |
+
patient_context = state.get("patient_context", {})
|
| 34 |
+
gender = patient_context.get("gender") # None if not provided β uses non-gender-specific ranges
|
| 35 |
+
predicted_disease = state["model_prediction"]["disease"]
|
| 36 |
|
| 37 |
# Validate all biomarkers
|
| 38 |
print(f"\nValidating {len(biomarkers)} biomarkers...")
|
| 39 |
flags, alerts = self.validator.validate_all(
|
| 40 |
+
biomarkers=biomarkers, gender=gender, threshold_pct=state["sop"].biomarker_analyzer_threshold
|
|
|
|
|
|
|
| 41 |
)
|
| 42 |
|
| 43 |
# Get disease-relevant biomarkers
|
|
|
|
| 51 |
"safety_alerts": [alert.model_dump() for alert in alerts],
|
| 52 |
"relevant_biomarkers": relevant_biomarkers,
|
| 53 |
"summary": summary,
|
| 54 |
+
"validation_complete": True,
|
| 55 |
}
|
| 56 |
|
| 57 |
# Create agent output
|
| 58 |
+
output = AgentOutput(agent_name="Biomarker Analyzer", findings=findings)
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
# Update state
|
| 61 |
print("\nAnalysis complete:")
|
|
|
|
| 65 |
print(f" - {len(relevant_biomarkers)} disease-relevant biomarkers identified")
|
| 66 |
|
| 67 |
return {
|
| 68 |
+
"agent_outputs": [output],
|
| 69 |
+
"biomarker_flags": flags,
|
| 70 |
+
"safety_alerts": alerts,
|
| 71 |
+
"biomarker_analysis": findings,
|
| 72 |
}
|
| 73 |
|
| 74 |
def _generate_summary(
|
|
|
|
| 77 |
flags: list[BiomarkerFlag],
|
| 78 |
alerts: list,
|
| 79 |
relevant_biomarkers: list[str],
|
| 80 |
+
disease: str,
|
| 81 |
) -> str:
|
| 82 |
"""Generate a concise summary of biomarker findings"""
|
| 83 |
|
| 84 |
# Count anomalies
|
| 85 |
+
critical = [f for f in flags if "CRITICAL" in f.status]
|
| 86 |
+
high_low = [f for f in flags if f.status in ["HIGH", "LOW"]]
|
| 87 |
|
| 88 |
prompt = f"""You are a medical data analyst. Provide a brief, clinical summary of these biomarker results.
|
| 89 |
|
src/agents/biomarker_linker.py
CHANGED
|
@@ -3,8 +3,6 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Biomarker-Disease Linker Agent - Connects biomarker values to predicted disease
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
from src.llm_config import llm_config
|
| 9 |
from src.state import AgentOutput, GuildState, KeyDriver
|
| 10 |
|
|
@@ -15,7 +13,7 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 15 |
def __init__(self, retriever):
|
| 16 |
"""
|
| 17 |
Initialize with a retriever for biomarker-disease connections.
|
| 18 |
-
|
| 19 |
Args:
|
| 20 |
retriever: Vector store retriever for biomarker evidence
|
| 21 |
"""
|
|
@@ -25,32 +23,27 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 25 |
def link(self, state: GuildState) -> GuildState:
|
| 26 |
"""
|
| 27 |
Link biomarkers to disease prediction.
|
| 28 |
-
|
| 29 |
Args:
|
| 30 |
state: Current guild state
|
| 31 |
-
|
| 32 |
Returns:
|
| 33 |
Updated state with biomarker-disease links
|
| 34 |
"""
|
| 35 |
-
print("\n" + "="*70)
|
| 36 |
print("EXECUTING: Biomarker-Disease Linker Agent (RAG)")
|
| 37 |
-
print("="*70)
|
| 38 |
|
| 39 |
-
model_prediction = state[
|
| 40 |
-
disease = model_prediction[
|
| 41 |
-
biomarkers = state[
|
| 42 |
|
| 43 |
# Get biomarker analysis from previous agent
|
| 44 |
-
biomarker_analysis = state.get(
|
| 45 |
|
| 46 |
# Identify key drivers
|
| 47 |
print(f"\nIdentifying key drivers for {disease}...")
|
| 48 |
-
key_drivers, citations_missing = self._identify_key_drivers(
|
| 49 |
-
disease,
|
| 50 |
-
biomarkers,
|
| 51 |
-
biomarker_analysis,
|
| 52 |
-
state
|
| 53 |
-
)
|
| 54 |
|
| 55 |
print(f"Identified {len(key_drivers)} key biomarker drivers")
|
| 56 |
|
|
@@ -62,39 +55,29 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 62 |
"key_drivers": [kd.model_dump() for kd in key_drivers],
|
| 63 |
"total_drivers": len(key_drivers),
|
| 64 |
"feature_importance_calculated": True,
|
| 65 |
-
"citations_missing": citations_missing
|
| 66 |
-
}
|
| 67 |
)
|
| 68 |
|
| 69 |
# Update state
|
| 70 |
print("\nBiomarker-disease linking complete")
|
| 71 |
|
| 72 |
-
return {
|
| 73 |
|
| 74 |
def _identify_key_drivers(
|
| 75 |
-
self,
|
| 76 |
-
disease: str,
|
| 77 |
-
biomarkers: dict[str, float],
|
| 78 |
-
analysis: dict,
|
| 79 |
-
state: GuildState
|
| 80 |
) -> tuple[list[KeyDriver], bool]:
|
| 81 |
"""Identify which biomarkers are driving the disease prediction"""
|
| 82 |
|
| 83 |
# Get out-of-range biomarkers from analysis
|
| 84 |
-
flags = analysis.get(
|
| 85 |
-
abnormal_biomarkers = [
|
| 86 |
-
f for f in flags
|
| 87 |
-
if f['status'] != 'NORMAL'
|
| 88 |
-
]
|
| 89 |
|
| 90 |
# Get disease-relevant biomarkers
|
| 91 |
-
relevant = analysis.get(
|
| 92 |
|
| 93 |
# Focus on biomarkers that are both abnormal AND disease-relevant
|
| 94 |
-
key_biomarkers = [
|
| 95 |
-
f for f in abnormal_biomarkers
|
| 96 |
-
if f['name'] in relevant
|
| 97 |
-
]
|
| 98 |
|
| 99 |
# If no key biomarkers found, use top abnormal ones
|
| 100 |
if not key_biomarkers:
|
|
@@ -106,28 +89,19 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 106 |
key_drivers: list[KeyDriver] = []
|
| 107 |
citations_missing = False
|
| 108 |
for biomarker_flag in key_biomarkers[:5]: # Top 5
|
| 109 |
-
driver, driver_missing = self._create_key_driver(
|
| 110 |
-
biomarker_flag,
|
| 111 |
-
disease,
|
| 112 |
-
state
|
| 113 |
-
)
|
| 114 |
key_drivers.append(driver)
|
| 115 |
citations_missing = citations_missing or driver_missing
|
| 116 |
|
| 117 |
return key_drivers, citations_missing
|
| 118 |
|
| 119 |
-
def _create_key_driver(
|
| 120 |
-
self,
|
| 121 |
-
biomarker_flag: dict,
|
| 122 |
-
disease: str,
|
| 123 |
-
state: GuildState
|
| 124 |
-
) -> tuple[KeyDriver, bool]:
|
| 125 |
"""Create a KeyDriver object with evidence from RAG"""
|
| 126 |
|
| 127 |
-
name = biomarker_flag[
|
| 128 |
-
value = biomarker_flag[
|
| 129 |
-
unit = biomarker_flag[
|
| 130 |
-
status = biomarker_flag[
|
| 131 |
|
| 132 |
# Retrieve evidence linking this biomarker to the disease
|
| 133 |
query = f"How does {name} relate to {disease}? What does {status} {name} indicate?"
|
|
@@ -135,7 +109,7 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 135 |
citations_missing = False
|
| 136 |
try:
|
| 137 |
docs = self.retriever.invoke(query)
|
| 138 |
-
if state[
|
| 139 |
evidence_text = "Insufficient evidence available in the knowledge base."
|
| 140 |
contribution = "Unknown"
|
| 141 |
citations_missing = True
|
|
@@ -149,16 +123,14 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 149 |
citations_missing = True
|
| 150 |
|
| 151 |
# Generate explanation using LLM
|
| 152 |
-
explanation = self._generate_explanation(
|
| 153 |
-
name, value, unit, status, disease, evidence_text
|
| 154 |
-
)
|
| 155 |
|
| 156 |
driver = KeyDriver(
|
| 157 |
biomarker=name,
|
| 158 |
value=value,
|
| 159 |
contribution=contribution,
|
| 160 |
explanation=explanation,
|
| 161 |
-
evidence=evidence_text[:500] # Truncate long evidence
|
| 162 |
)
|
| 163 |
|
| 164 |
return driver, citations_missing
|
|
@@ -173,10 +145,9 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 173 |
for doc in docs[:2]: # Top 2 docs
|
| 174 |
content = doc.page_content
|
| 175 |
# Extract sentences mentioning the biomarker
|
| 176 |
-
sentences = content.split(
|
| 177 |
relevant_sentences = [
|
| 178 |
-
s.strip() for s in sentences
|
| 179 |
-
if biomarker.lower() in s.lower() or disease.lower() in s.lower()
|
| 180 |
]
|
| 181 |
evidence.extend(relevant_sentences[:2])
|
| 182 |
|
|
@@ -184,12 +155,12 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 184 |
|
| 185 |
def _estimate_contribution(self, biomarker_flag: dict, doc_count: int) -> str:
|
| 186 |
"""Estimate the contribution percentage (simplified)"""
|
| 187 |
-
status = biomarker_flag[
|
| 188 |
|
| 189 |
# Simple heuristic based on severity
|
| 190 |
-
if
|
| 191 |
base = 40
|
| 192 |
-
elif status in [
|
| 193 |
base = 25
|
| 194 |
else:
|
| 195 |
base = 10
|
|
@@ -201,13 +172,7 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 201 |
return f"{total}%"
|
| 202 |
|
| 203 |
def _generate_explanation(
|
| 204 |
-
self,
|
| 205 |
-
biomarker: str,
|
| 206 |
-
value: float,
|
| 207 |
-
unit: str,
|
| 208 |
-
status: str,
|
| 209 |
-
disease: str,
|
| 210 |
-
evidence: str
|
| 211 |
) -> str:
|
| 212 |
"""Generate patient-friendly explanation"""
|
| 213 |
|
|
|
|
| 3 |
Biomarker-Disease Linker Agent - Connects biomarker values to predicted disease
|
| 4 |
"""
|
| 5 |
|
|
|
|
|
|
|
| 6 |
from src.llm_config import llm_config
|
| 7 |
from src.state import AgentOutput, GuildState, KeyDriver
|
| 8 |
|
|
|
|
| 13 |
def __init__(self, retriever):
|
| 14 |
"""
|
| 15 |
Initialize with a retriever for biomarker-disease connections.
|
| 16 |
+
|
| 17 |
Args:
|
| 18 |
retriever: Vector store retriever for biomarker evidence
|
| 19 |
"""
|
|
|
|
| 23 |
def link(self, state: GuildState) -> GuildState:
|
| 24 |
"""
|
| 25 |
Link biomarkers to disease prediction.
|
| 26 |
+
|
| 27 |
Args:
|
| 28 |
state: Current guild state
|
| 29 |
+
|
| 30 |
Returns:
|
| 31 |
Updated state with biomarker-disease links
|
| 32 |
"""
|
| 33 |
+
print("\n" + "=" * 70)
|
| 34 |
print("EXECUTING: Biomarker-Disease Linker Agent (RAG)")
|
| 35 |
+
print("=" * 70)
|
| 36 |
|
| 37 |
+
model_prediction = state["model_prediction"]
|
| 38 |
+
disease = model_prediction["disease"]
|
| 39 |
+
biomarkers = state["patient_biomarkers"]
|
| 40 |
|
| 41 |
# Get biomarker analysis from previous agent
|
| 42 |
+
biomarker_analysis = state.get("biomarker_analysis") or {}
|
| 43 |
|
| 44 |
# Identify key drivers
|
| 45 |
print(f"\nIdentifying key drivers for {disease}...")
|
| 46 |
+
key_drivers, citations_missing = self._identify_key_drivers(disease, biomarkers, biomarker_analysis, state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
print(f"Identified {len(key_drivers)} key biomarker drivers")
|
| 49 |
|
|
|
|
| 55 |
"key_drivers": [kd.model_dump() for kd in key_drivers],
|
| 56 |
"total_drivers": len(key_drivers),
|
| 57 |
"feature_importance_calculated": True,
|
| 58 |
+
"citations_missing": citations_missing,
|
| 59 |
+
},
|
| 60 |
)
|
| 61 |
|
| 62 |
# Update state
|
| 63 |
print("\nBiomarker-disease linking complete")
|
| 64 |
|
| 65 |
+
return {"agent_outputs": [output]}
|
| 66 |
|
| 67 |
def _identify_key_drivers(
|
| 68 |
+
self, disease: str, biomarkers: dict[str, float], analysis: dict, state: GuildState
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
) -> tuple[list[KeyDriver], bool]:
|
| 70 |
"""Identify which biomarkers are driving the disease prediction"""
|
| 71 |
|
| 72 |
# Get out-of-range biomarkers from analysis
|
| 73 |
+
flags = analysis.get("biomarker_flags", [])
|
| 74 |
+
abnormal_biomarkers = [f for f in flags if f["status"] != "NORMAL"]
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# Get disease-relevant biomarkers
|
| 77 |
+
relevant = analysis.get("relevant_biomarkers", [])
|
| 78 |
|
| 79 |
# Focus on biomarkers that are both abnormal AND disease-relevant
|
| 80 |
+
key_biomarkers = [f for f in abnormal_biomarkers if f["name"] in relevant]
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
# If no key biomarkers found, use top abnormal ones
|
| 83 |
if not key_biomarkers:
|
|
|
|
| 89 |
key_drivers: list[KeyDriver] = []
|
| 90 |
citations_missing = False
|
| 91 |
for biomarker_flag in key_biomarkers[:5]: # Top 5
|
| 92 |
+
driver, driver_missing = self._create_key_driver(biomarker_flag, disease, state)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
key_drivers.append(driver)
|
| 94 |
citations_missing = citations_missing or driver_missing
|
| 95 |
|
| 96 |
return key_drivers, citations_missing
|
| 97 |
|
| 98 |
+
def _create_key_driver(self, biomarker_flag: dict, disease: str, state: GuildState) -> tuple[KeyDriver, bool]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
"""Create a KeyDriver object with evidence from RAG"""
|
| 100 |
|
| 101 |
+
name = biomarker_flag["name"]
|
| 102 |
+
value = biomarker_flag["value"]
|
| 103 |
+
unit = biomarker_flag["unit"]
|
| 104 |
+
status = biomarker_flag["status"]
|
| 105 |
|
| 106 |
# Retrieve evidence linking this biomarker to the disease
|
| 107 |
query = f"How does {name} relate to {disease}? What does {status} {name} indicate?"
|
|
|
|
| 109 |
citations_missing = False
|
| 110 |
try:
|
| 111 |
docs = self.retriever.invoke(query)
|
| 112 |
+
if state["sop"].require_pdf_citations and not docs:
|
| 113 |
evidence_text = "Insufficient evidence available in the knowledge base."
|
| 114 |
contribution = "Unknown"
|
| 115 |
citations_missing = True
|
|
|
|
| 123 |
citations_missing = True
|
| 124 |
|
| 125 |
# Generate explanation using LLM
|
| 126 |
+
explanation = self._generate_explanation(name, value, unit, status, disease, evidence_text)
|
|
|
|
|
|
|
| 127 |
|
| 128 |
driver = KeyDriver(
|
| 129 |
biomarker=name,
|
| 130 |
value=value,
|
| 131 |
contribution=contribution,
|
| 132 |
explanation=explanation,
|
| 133 |
+
evidence=evidence_text[:500], # Truncate long evidence
|
| 134 |
)
|
| 135 |
|
| 136 |
return driver, citations_missing
|
|
|
|
| 145 |
for doc in docs[:2]: # Top 2 docs
|
| 146 |
content = doc.page_content
|
| 147 |
# Extract sentences mentioning the biomarker
|
| 148 |
+
sentences = content.split(".")
|
| 149 |
relevant_sentences = [
|
| 150 |
+
s.strip() for s in sentences if biomarker.lower() in s.lower() or disease.lower() in s.lower()
|
|
|
|
| 151 |
]
|
| 152 |
evidence.extend(relevant_sentences[:2])
|
| 153 |
|
|
|
|
| 155 |
|
| 156 |
def _estimate_contribution(self, biomarker_flag: dict, doc_count: int) -> str:
|
| 157 |
"""Estimate the contribution percentage (simplified)"""
|
| 158 |
+
status = biomarker_flag["status"]
|
| 159 |
|
| 160 |
# Simple heuristic based on severity
|
| 161 |
+
if "CRITICAL" in status:
|
| 162 |
base = 40
|
| 163 |
+
elif status in ["HIGH", "LOW"]:
|
| 164 |
base = 25
|
| 165 |
else:
|
| 166 |
base = 10
|
|
|
|
| 172 |
return f"{total}%"
|
| 173 |
|
| 174 |
def _generate_explanation(
|
| 175 |
+
self, biomarker: str, value: float, unit: str, status: str, disease: str, evidence: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
) -> str:
|
| 177 |
"""Generate patient-friendly explanation"""
|
| 178 |
|
src/agents/clinical_guidelines.py
CHANGED
|
@@ -17,7 +17,7 @@ class ClinicalGuidelinesAgent:
|
|
| 17 |
def __init__(self, retriever):
|
| 18 |
"""
|
| 19 |
Initialize with a retriever for clinical guidelines.
|
| 20 |
-
|
| 21 |
Args:
|
| 22 |
retriever: Vector store retriever for guidelines documents
|
| 23 |
"""
|
|
@@ -27,24 +27,24 @@ class ClinicalGuidelinesAgent:
|
|
| 27 |
def recommend(self, state: GuildState) -> GuildState:
|
| 28 |
"""
|
| 29 |
Retrieve clinical guidelines and generate recommendations.
|
| 30 |
-
|
| 31 |
Args:
|
| 32 |
state: Current guild state
|
| 33 |
-
|
| 34 |
Returns:
|
| 35 |
Updated state with clinical recommendations
|
| 36 |
"""
|
| 37 |
-
print("\n" + "="*70)
|
| 38 |
print("EXECUTING: Clinical Guidelines Agent (RAG)")
|
| 39 |
-
print("="*70)
|
| 40 |
|
| 41 |
-
model_prediction = state[
|
| 42 |
-
disease = model_prediction[
|
| 43 |
-
confidence = model_prediction[
|
| 44 |
|
| 45 |
# Get biomarker analysis
|
| 46 |
-
biomarker_analysis = state.get(
|
| 47 |
-
safety_alerts = biomarker_analysis.get(
|
| 48 |
|
| 49 |
# Retrieve guidelines
|
| 50 |
print(f"\nRetrieving clinical guidelines for {disease}...")
|
|
@@ -57,36 +57,30 @@ class ClinicalGuidelinesAgent:
|
|
| 57 |
print(f"Retrieved {len(docs)} guideline documents")
|
| 58 |
|
| 59 |
# Generate recommendations
|
| 60 |
-
if state[
|
| 61 |
recommendations = {
|
| 62 |
"immediate_actions": [
|
| 63 |
"Insufficient evidence available in the knowledge base. Please consult a healthcare provider."
|
| 64 |
],
|
| 65 |
"lifestyle_changes": [],
|
| 66 |
"monitoring": [],
|
| 67 |
-
"citations": []
|
| 68 |
}
|
| 69 |
else:
|
| 70 |
-
recommendations = self._generate_recommendations(
|
| 71 |
-
disease,
|
| 72 |
-
docs,
|
| 73 |
-
safety_alerts,
|
| 74 |
-
confidence,
|
| 75 |
-
state
|
| 76 |
-
)
|
| 77 |
|
| 78 |
# Create agent output
|
| 79 |
output = AgentOutput(
|
| 80 |
agent_name="Clinical Guidelines",
|
| 81 |
findings={
|
| 82 |
"disease": disease,
|
| 83 |
-
"immediate_actions": recommendations[
|
| 84 |
-
"lifestyle_changes": recommendations[
|
| 85 |
-
"monitoring": recommendations[
|
| 86 |
-
"guideline_citations": recommendations[
|
| 87 |
"safety_priority": len(safety_alerts) > 0,
|
| 88 |
-
"citations_missing": state[
|
| 89 |
-
}
|
| 90 |
)
|
| 91 |
|
| 92 |
# Update state
|
|
@@ -95,23 +89,17 @@ class ClinicalGuidelinesAgent:
|
|
| 95 |
print(f" - Lifestyle changes: {len(recommendations['lifestyle_changes'])}")
|
| 96 |
print(f" - Monitoring recommendations: {len(recommendations['monitoring'])}")
|
| 97 |
|
| 98 |
-
return {
|
| 99 |
|
| 100 |
def _generate_recommendations(
|
| 101 |
-
self,
|
| 102 |
-
disease: str,
|
| 103 |
-
docs: list,
|
| 104 |
-
safety_alerts: list,
|
| 105 |
-
confidence: float,
|
| 106 |
-
state: GuildState
|
| 107 |
) -> dict:
|
| 108 |
"""Generate structured recommendations using LLM and guidelines"""
|
| 109 |
|
| 110 |
# Format retrieved guidelines
|
| 111 |
-
guidelines_context = "\n\n---\n\n".join(
|
| 112 |
-
f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
|
| 113 |
-
|
| 114 |
-
])
|
| 115 |
|
| 116 |
# Build safety context
|
| 117 |
safety_context = ""
|
|
@@ -120,8 +108,11 @@ class ClinicalGuidelinesAgent:
|
|
| 120 |
for alert in safety_alerts[:3]:
|
| 121 |
safety_context += f"- {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}\n"
|
| 122 |
|
| 123 |
-
prompt = ChatPromptTemplate.from_messages(
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
| 125 |
Based on clinical practice guidelines, provide actionable recommendations for patient self-assessment.
|
| 126 |
|
| 127 |
Structure your response with these sections:
|
|
@@ -130,26 +121,33 @@ class ClinicalGuidelinesAgent:
|
|
| 130 |
3. MONITORING: What to track and how often
|
| 131 |
|
| 132 |
Make recommendations specific, actionable, and guideline-aligned.
|
| 133 |
-
Always emphasize consulting healthcare professionals for diagnosis and treatment."""
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
| 135 |
Prediction Confidence: {confidence:.1%}
|
| 136 |
{safety_context}
|
| 137 |
|
| 138 |
Clinical Guidelines Context:
|
| 139 |
{guidelines}
|
| 140 |
|
| 141 |
-
Please provide structured recommendations for patient self-assessment."""
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
|
| 144 |
chain = prompt | self.llm
|
| 145 |
|
| 146 |
try:
|
| 147 |
-
response = chain.invoke(
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
|
| 154 |
recommendations = self._parse_recommendations(response.content)
|
| 155 |
|
|
@@ -158,82 +156,76 @@ class ClinicalGuidelinesAgent:
|
|
| 158 |
recommendations = self._get_default_recommendations(disease, safety_alerts)
|
| 159 |
|
| 160 |
# Add citations
|
| 161 |
-
recommendations[
|
| 162 |
|
| 163 |
return recommendations
|
| 164 |
|
| 165 |
def _parse_recommendations(self, content: str) -> dict:
|
| 166 |
"""Parse LLM response into structured recommendations"""
|
| 167 |
-
recommendations = {
|
| 168 |
-
"immediate_actions": [],
|
| 169 |
-
"lifestyle_changes": [],
|
| 170 |
-
"monitoring": []
|
| 171 |
-
}
|
| 172 |
|
| 173 |
current_section = None
|
| 174 |
-
lines = content.split(
|
| 175 |
|
| 176 |
for line in lines:
|
| 177 |
line_stripped = line.strip()
|
| 178 |
line_upper = line_stripped.upper()
|
| 179 |
|
| 180 |
# Detect section headers
|
| 181 |
-
if
|
| 182 |
-
current_section =
|
| 183 |
-
elif
|
| 184 |
-
current_section =
|
| 185 |
-
elif
|
| 186 |
-
current_section =
|
| 187 |
# Add bullet points or numbered items
|
| 188 |
elif current_section and line_stripped:
|
| 189 |
# Remove bullet points and numbers
|
| 190 |
-
cleaned = line_stripped.lstrip(
|
| 191 |
if cleaned and len(cleaned) > 10: # Minimum length filter
|
| 192 |
recommendations[current_section].append(cleaned)
|
| 193 |
|
| 194 |
# If parsing failed, create default structure
|
| 195 |
if not any(recommendations.values()):
|
| 196 |
-
sentences = content.split(
|
| 197 |
-
recommendations[
|
| 198 |
-
recommendations[
|
| 199 |
-
recommendations[
|
| 200 |
|
| 201 |
return recommendations
|
| 202 |
|
| 203 |
def _get_default_recommendations(self, disease: str, safety_alerts: list) -> dict:
|
| 204 |
"""Provide default recommendations if LLM fails"""
|
| 205 |
-
recommendations = {
|
| 206 |
-
"immediate_actions": [],
|
| 207 |
-
"lifestyle_changes": [],
|
| 208 |
-
"monitoring": []
|
| 209 |
-
}
|
| 210 |
|
| 211 |
# Add safety-based immediate actions
|
| 212 |
if safety_alerts:
|
| 213 |
-
recommendations[
|
| 214 |
"Consult healthcare provider immediately regarding critical biomarker values"
|
| 215 |
)
|
| 216 |
-
recommendations[
|
| 217 |
-
"Bring this report and recent lab results to your appointment"
|
| 218 |
-
)
|
| 219 |
else:
|
| 220 |
-
recommendations[
|
| 221 |
f"Schedule appointment with healthcare provider to discuss {disease} findings"
|
| 222 |
)
|
| 223 |
|
| 224 |
# Generic lifestyle changes
|
| 225 |
-
recommendations[
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
| 230 |
|
| 231 |
# Generic monitoring
|
| 232 |
-
recommendations[
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
| 237 |
|
| 238 |
return recommendations
|
| 239 |
|
|
@@ -242,10 +234,10 @@ class ClinicalGuidelinesAgent:
|
|
| 242 |
citations = []
|
| 243 |
|
| 244 |
for doc in docs:
|
| 245 |
-
source = doc.metadata.get(
|
| 246 |
|
| 247 |
# Clean up source path
|
| 248 |
-
if
|
| 249 |
source = Path(source).name
|
| 250 |
|
| 251 |
citations.append(source)
|
|
|
|
| 17 |
def __init__(self, retriever):
|
| 18 |
"""
|
| 19 |
Initialize with a retriever for clinical guidelines.
|
| 20 |
+
|
| 21 |
Args:
|
| 22 |
retriever: Vector store retriever for guidelines documents
|
| 23 |
"""
|
|
|
|
| 27 |
def recommend(self, state: GuildState) -> GuildState:
|
| 28 |
"""
|
| 29 |
Retrieve clinical guidelines and generate recommendations.
|
| 30 |
+
|
| 31 |
Args:
|
| 32 |
state: Current guild state
|
| 33 |
+
|
| 34 |
Returns:
|
| 35 |
Updated state with clinical recommendations
|
| 36 |
"""
|
| 37 |
+
print("\n" + "=" * 70)
|
| 38 |
print("EXECUTING: Clinical Guidelines Agent (RAG)")
|
| 39 |
+
print("=" * 70)
|
| 40 |
|
| 41 |
+
model_prediction = state["model_prediction"]
|
| 42 |
+
disease = model_prediction["disease"]
|
| 43 |
+
confidence = model_prediction["confidence"]
|
| 44 |
|
| 45 |
# Get biomarker analysis
|
| 46 |
+
biomarker_analysis = state.get("biomarker_analysis") or {}
|
| 47 |
+
safety_alerts = biomarker_analysis.get("safety_alerts", [])
|
| 48 |
|
| 49 |
# Retrieve guidelines
|
| 50 |
print(f"\nRetrieving clinical guidelines for {disease}...")
|
|
|
|
| 57 |
print(f"Retrieved {len(docs)} guideline documents")
|
| 58 |
|
| 59 |
# Generate recommendations
|
| 60 |
+
if state["sop"].require_pdf_citations and not docs:
|
| 61 |
recommendations = {
|
| 62 |
"immediate_actions": [
|
| 63 |
"Insufficient evidence available in the knowledge base. Please consult a healthcare provider."
|
| 64 |
],
|
| 65 |
"lifestyle_changes": [],
|
| 66 |
"monitoring": [],
|
| 67 |
+
"citations": [],
|
| 68 |
}
|
| 69 |
else:
|
| 70 |
+
recommendations = self._generate_recommendations(disease, docs, safety_alerts, confidence, state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Create agent output
|
| 73 |
output = AgentOutput(
|
| 74 |
agent_name="Clinical Guidelines",
|
| 75 |
findings={
|
| 76 |
"disease": disease,
|
| 77 |
+
"immediate_actions": recommendations["immediate_actions"],
|
| 78 |
+
"lifestyle_changes": recommendations["lifestyle_changes"],
|
| 79 |
+
"monitoring": recommendations["monitoring"],
|
| 80 |
+
"guideline_citations": recommendations["citations"],
|
| 81 |
"safety_priority": len(safety_alerts) > 0,
|
| 82 |
+
"citations_missing": state["sop"].require_pdf_citations and not docs,
|
| 83 |
+
},
|
| 84 |
)
|
| 85 |
|
| 86 |
# Update state
|
|
|
|
| 89 |
print(f" - Lifestyle changes: {len(recommendations['lifestyle_changes'])}")
|
| 90 |
print(f" - Monitoring recommendations: {len(recommendations['monitoring'])}")
|
| 91 |
|
| 92 |
+
return {"agent_outputs": [output]}
|
| 93 |
|
| 94 |
def _generate_recommendations(
|
| 95 |
+
self, disease: str, docs: list, safety_alerts: list, confidence: float, state: GuildState
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
) -> dict:
|
| 97 |
"""Generate structured recommendations using LLM and guidelines"""
|
| 98 |
|
| 99 |
# Format retrieved guidelines
|
| 100 |
+
guidelines_context = "\n\n---\n\n".join(
|
| 101 |
+
[f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}" for doc in docs]
|
| 102 |
+
)
|
|
|
|
| 103 |
|
| 104 |
# Build safety context
|
| 105 |
safety_context = ""
|
|
|
|
| 108 |
for alert in safety_alerts[:3]:
|
| 109 |
safety_context += f"- {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}\n"
|
| 110 |
|
| 111 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 112 |
+
[
|
| 113 |
+
(
|
| 114 |
+
"system",
|
| 115 |
+
"""You are a clinical decision support system providing evidence-based recommendations.
|
| 116 |
Based on clinical practice guidelines, provide actionable recommendations for patient self-assessment.
|
| 117 |
|
| 118 |
Structure your response with these sections:
|
|
|
|
| 121 |
3. MONITORING: What to track and how often
|
| 122 |
|
| 123 |
Make recommendations specific, actionable, and guideline-aligned.
|
| 124 |
+
Always emphasize consulting healthcare professionals for diagnosis and treatment.""",
|
| 125 |
+
),
|
| 126 |
+
(
|
| 127 |
+
"human",
|
| 128 |
+
"""Disease: {disease}
|
| 129 |
Prediction Confidence: {confidence:.1%}
|
| 130 |
{safety_context}
|
| 131 |
|
| 132 |
Clinical Guidelines Context:
|
| 133 |
{guidelines}
|
| 134 |
|
| 135 |
+
Please provide structured recommendations for patient self-assessment.""",
|
| 136 |
+
),
|
| 137 |
+
]
|
| 138 |
+
)
|
| 139 |
|
| 140 |
chain = prompt | self.llm
|
| 141 |
|
| 142 |
try:
|
| 143 |
+
response = chain.invoke(
|
| 144 |
+
{
|
| 145 |
+
"disease": disease,
|
| 146 |
+
"confidence": confidence,
|
| 147 |
+
"safety_context": safety_context,
|
| 148 |
+
"guidelines": guidelines_context,
|
| 149 |
+
}
|
| 150 |
+
)
|
| 151 |
|
| 152 |
recommendations = self._parse_recommendations(response.content)
|
| 153 |
|
|
|
|
| 156 |
recommendations = self._get_default_recommendations(disease, safety_alerts)
|
| 157 |
|
| 158 |
# Add citations
|
| 159 |
+
recommendations["citations"] = self._extract_citations(docs)
|
| 160 |
|
| 161 |
return recommendations
|
| 162 |
|
| 163 |
def _parse_recommendations(self, content: str) -> dict:
|
| 164 |
"""Parse LLM response into structured recommendations"""
|
| 165 |
+
recommendations = {"immediate_actions": [], "lifestyle_changes": [], "monitoring": []}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
current_section = None
|
| 168 |
+
lines = content.split("\n")
|
| 169 |
|
| 170 |
for line in lines:
|
| 171 |
line_stripped = line.strip()
|
| 172 |
line_upper = line_stripped.upper()
|
| 173 |
|
| 174 |
# Detect section headers
|
| 175 |
+
if "IMMEDIATE" in line_upper or "URGENT" in line_upper:
|
| 176 |
+
current_section = "immediate_actions"
|
| 177 |
+
elif "LIFESTYLE" in line_upper or "CHANGES" in line_upper or "DIET" in line_upper:
|
| 178 |
+
current_section = "lifestyle_changes"
|
| 179 |
+
elif "MONITORING" in line_upper or "TRACK" in line_upper:
|
| 180 |
+
current_section = "monitoring"
|
| 181 |
# Add bullet points or numbered items
|
| 182 |
elif current_section and line_stripped:
|
| 183 |
# Remove bullet points and numbers
|
| 184 |
+
cleaned = line_stripped.lstrip("β’-*0123456789. ")
|
| 185 |
if cleaned and len(cleaned) > 10: # Minimum length filter
|
| 186 |
recommendations[current_section].append(cleaned)
|
| 187 |
|
| 188 |
# If parsing failed, create default structure
|
| 189 |
if not any(recommendations.values()):
|
| 190 |
+
sentences = content.split(".")
|
| 191 |
+
recommendations["immediate_actions"] = [s.strip() for s in sentences[:2] if s.strip()]
|
| 192 |
+
recommendations["lifestyle_changes"] = [s.strip() for s in sentences[2:4] if s.strip()]
|
| 193 |
+
recommendations["monitoring"] = [s.strip() for s in sentences[4:6] if s.strip()]
|
| 194 |
|
| 195 |
return recommendations
|
| 196 |
|
| 197 |
def _get_default_recommendations(self, disease: str, safety_alerts: list) -> dict:
|
| 198 |
"""Provide default recommendations if LLM fails"""
|
| 199 |
+
recommendations = {"immediate_actions": [], "lifestyle_changes": [], "monitoring": []}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
# Add safety-based immediate actions
|
| 202 |
if safety_alerts:
|
| 203 |
+
recommendations["immediate_actions"].append(
|
| 204 |
"Consult healthcare provider immediately regarding critical biomarker values"
|
| 205 |
)
|
| 206 |
+
recommendations["immediate_actions"].append("Bring this report and recent lab results to your appointment")
|
|
|
|
|
|
|
| 207 |
else:
|
| 208 |
+
recommendations["immediate_actions"].append(
|
| 209 |
f"Schedule appointment with healthcare provider to discuss {disease} findings"
|
| 210 |
)
|
| 211 |
|
| 212 |
# Generic lifestyle changes
|
| 213 |
+
recommendations["lifestyle_changes"].extend(
|
| 214 |
+
[
|
| 215 |
+
"Follow a balanced, nutrient-rich diet as recommended by healthcare provider",
|
| 216 |
+
"Maintain regular physical activity appropriate for your health status",
|
| 217 |
+
"Track symptoms and biomarker trends over time",
|
| 218 |
+
]
|
| 219 |
+
)
|
| 220 |
|
| 221 |
# Generic monitoring
|
| 222 |
+
recommendations["monitoring"].extend(
|
| 223 |
+
[
|
| 224 |
+
f"Regular monitoring of {disease}-related biomarkers as advised by physician",
|
| 225 |
+
"Keep a health journal tracking symptoms, diet, and activities",
|
| 226 |
+
"Schedule follow-up appointments as recommended",
|
| 227 |
+
]
|
| 228 |
+
)
|
| 229 |
|
| 230 |
return recommendations
|
| 231 |
|
|
|
|
| 234 |
citations = []
|
| 235 |
|
| 236 |
for doc in docs:
|
| 237 |
+
source = doc.metadata.get("source", "Unknown")
|
| 238 |
|
| 239 |
# Clean up source path
|
| 240 |
+
if "\\" in source or "/" in source:
|
| 241 |
source = Path(source).name
|
| 242 |
|
| 243 |
citations.append(source)
|
src/agents/confidence_assessor.py
CHANGED
|
@@ -19,58 +19,42 @@ class ConfidenceAssessorAgent:
|
|
| 19 |
def assess(self, state: GuildState) -> GuildState:
|
| 20 |
"""
|
| 21 |
Assess prediction confidence and identify limitations.
|
| 22 |
-
|
| 23 |
Args:
|
| 24 |
state: Current guild state
|
| 25 |
-
|
| 26 |
Returns:
|
| 27 |
Updated state with confidence assessment
|
| 28 |
"""
|
| 29 |
-
print("\n" + "="*70)
|
| 30 |
print("EXECUTING: Confidence Assessor Agent")
|
| 31 |
-
print("="*70)
|
| 32 |
|
| 33 |
-
model_prediction = state[
|
| 34 |
-
disease = model_prediction[
|
| 35 |
-
ml_confidence = model_prediction[
|
| 36 |
-
probabilities = model_prediction.get(
|
| 37 |
-
biomarkers = state[
|
| 38 |
|
| 39 |
# Collect previous agent findings
|
| 40 |
-
biomarker_analysis = state.get(
|
| 41 |
disease_explanation = self._get_agent_findings(state, "Disease Explainer")
|
| 42 |
linker_findings = self._get_agent_findings(state, "Biomarker-Disease Linker")
|
| 43 |
|
| 44 |
print(f"\nAssessing confidence for {disease} prediction...")
|
| 45 |
|
| 46 |
# Evaluate evidence strength
|
| 47 |
-
evidence_strength = self._evaluate_evidence_strength(
|
| 48 |
-
biomarker_analysis,
|
| 49 |
-
disease_explanation,
|
| 50 |
-
linker_findings
|
| 51 |
-
)
|
| 52 |
|
| 53 |
# Identify limitations
|
| 54 |
-
limitations = self._identify_limitations(
|
| 55 |
-
biomarkers,
|
| 56 |
-
biomarker_analysis,
|
| 57 |
-
probabilities
|
| 58 |
-
)
|
| 59 |
|
| 60 |
# Calculate aggregate reliability
|
| 61 |
-
reliability = self._calculate_reliability(
|
| 62 |
-
ml_confidence,
|
| 63 |
-
evidence_strength,
|
| 64 |
-
len(limitations)
|
| 65 |
-
)
|
| 66 |
|
| 67 |
# Generate assessment summary
|
| 68 |
assessment_summary = self._generate_assessment(
|
| 69 |
-
disease,
|
| 70 |
-
ml_confidence,
|
| 71 |
-
reliability,
|
| 72 |
-
evidence_strength,
|
| 73 |
-
limitations
|
| 74 |
)
|
| 75 |
|
| 76 |
# Create agent output
|
|
@@ -83,8 +67,8 @@ class ConfidenceAssessorAgent:
|
|
| 83 |
"limitations": limitations,
|
| 84 |
"assessment_summary": assessment_summary,
|
| 85 |
"recommendation": self._get_recommendation(reliability),
|
| 86 |
-
"alternative_diagnoses": self._get_alternatives(probabilities)
|
| 87 |
-
}
|
| 88 |
)
|
| 89 |
|
| 90 |
# Update state
|
|
@@ -93,20 +77,17 @@ class ConfidenceAssessorAgent:
|
|
| 93 |
print(f" - Evidence strength: {evidence_strength}")
|
| 94 |
print(f" - Limitations identified: {len(limitations)}")
|
| 95 |
|
| 96 |
-
return {
|
| 97 |
|
| 98 |
def _get_agent_findings(self, state: GuildState, agent_name: str) -> dict:
|
| 99 |
"""Extract findings from a specific agent"""
|
| 100 |
-
for output in state.get(
|
| 101 |
if output.agent_name == agent_name:
|
| 102 |
return output.findings
|
| 103 |
return {}
|
| 104 |
|
| 105 |
def _evaluate_evidence_strength(
|
| 106 |
-
self,
|
| 107 |
-
biomarker_analysis: dict,
|
| 108 |
-
disease_explanation: dict,
|
| 109 |
-
linker_findings: dict
|
| 110 |
) -> str:
|
| 111 |
"""Evaluate the strength of supporting evidence"""
|
| 112 |
|
|
@@ -114,19 +95,19 @@ class ConfidenceAssessorAgent:
|
|
| 114 |
max_score = 5
|
| 115 |
|
| 116 |
# Check biomarker validation quality
|
| 117 |
-
flags = biomarker_analysis.get(
|
| 118 |
-
abnormal_count = len([f for f in flags if f.get(
|
| 119 |
if abnormal_count >= 3:
|
| 120 |
score += 1
|
| 121 |
if abnormal_count >= 5:
|
| 122 |
score += 1
|
| 123 |
|
| 124 |
# Check disease explanation quality
|
| 125 |
-
if disease_explanation.get(
|
| 126 |
score += 1
|
| 127 |
|
| 128 |
# Check biomarker-disease linking
|
| 129 |
-
key_drivers = linker_findings.get(
|
| 130 |
if len(key_drivers) >= 2:
|
| 131 |
score += 1
|
| 132 |
if len(key_drivers) >= 4:
|
|
@@ -141,10 +122,7 @@ class ConfidenceAssessorAgent:
|
|
| 141 |
return "WEAK"
|
| 142 |
|
| 143 |
def _identify_limitations(
|
| 144 |
-
self,
|
| 145 |
-
biomarkers: dict[str, float],
|
| 146 |
-
biomarker_analysis: dict,
|
| 147 |
-
probabilities: dict[str, float]
|
| 148 |
) -> list[str]:
|
| 149 |
"""Identify limitations and uncertainties"""
|
| 150 |
limitations = []
|
|
@@ -161,37 +139,23 @@ class ConfidenceAssessorAgent:
|
|
| 161 |
top1, prob1 = sorted_probs[0]
|
| 162 |
top2, prob2 = sorted_probs[1]
|
| 163 |
if prob2 > 0.15: # Alternative is significant
|
| 164 |
-
limitations.append(
|
| 165 |
-
f"Differential diagnosis: {top2} also possible ({prob2:.1%} probability)"
|
| 166 |
-
)
|
| 167 |
|
| 168 |
# Check for normal biomarkers despite prediction
|
| 169 |
-
flags = biomarker_analysis.get(
|
| 170 |
-
relevant = biomarker_analysis.get(
|
| 171 |
-
normal_relevant = [
|
| 172 |
-
f for f in flags
|
| 173 |
-
if f.get('name') in relevant and f.get('status') == 'NORMAL'
|
| 174 |
-
]
|
| 175 |
if len(normal_relevant) >= 2:
|
| 176 |
-
limitations.append(
|
| 177 |
-
"Some disease-relevant biomarkers are within normal range"
|
| 178 |
-
)
|
| 179 |
|
| 180 |
# Check for safety alerts (indicates complexity)
|
| 181 |
-
alerts = biomarker_analysis.get(
|
| 182 |
if len(alerts) >= 2:
|
| 183 |
-
limitations.append(
|
| 184 |
-
"Multiple critical values detected; professional evaluation essential"
|
| 185 |
-
)
|
| 186 |
|
| 187 |
return limitations
|
| 188 |
|
| 189 |
-
def _calculate_reliability(
|
| 190 |
-
self,
|
| 191 |
-
ml_confidence: float,
|
| 192 |
-
evidence_strength: str,
|
| 193 |
-
limitation_count: int
|
| 194 |
-
) -> str:
|
| 195 |
"""Calculate overall prediction reliability"""
|
| 196 |
|
| 197 |
score = 0
|
|
@@ -224,12 +188,7 @@ class ConfidenceAssessorAgent:
|
|
| 224 |
return "LOW"
|
| 225 |
|
| 226 |
def _generate_assessment(
|
| 227 |
-
self,
|
| 228 |
-
disease: str,
|
| 229 |
-
ml_confidence: float,
|
| 230 |
-
reliability: str,
|
| 231 |
-
evidence_strength: str,
|
| 232 |
-
limitations: list[str]
|
| 233 |
) -> str:
|
| 234 |
"""Generate human-readable assessment summary"""
|
| 235 |
|
|
@@ -271,11 +230,9 @@ Be honest about uncertainty. Patient safety is paramount."""
|
|
| 271 |
alternatives = []
|
| 272 |
for disease, prob in sorted_probs[1:4]: # Top 3 alternatives
|
| 273 |
if prob > 0.05: # Only significant alternatives
|
| 274 |
-
alternatives.append(
|
| 275 |
-
"disease": disease,
|
| 276 |
-
|
| 277 |
-
"note": "Consider discussing with healthcare provider"
|
| 278 |
-
})
|
| 279 |
|
| 280 |
return alternatives
|
| 281 |
|
|
|
|
| 19 |
def assess(self, state: GuildState) -> GuildState:
|
| 20 |
"""
|
| 21 |
Assess prediction confidence and identify limitations.
|
| 22 |
+
|
| 23 |
Args:
|
| 24 |
state: Current guild state
|
| 25 |
+
|
| 26 |
Returns:
|
| 27 |
Updated state with confidence assessment
|
| 28 |
"""
|
| 29 |
+
print("\n" + "=" * 70)
|
| 30 |
print("EXECUTING: Confidence Assessor Agent")
|
| 31 |
+
print("=" * 70)
|
| 32 |
|
| 33 |
+
model_prediction = state["model_prediction"]
|
| 34 |
+
disease = model_prediction["disease"]
|
| 35 |
+
ml_confidence = model_prediction["confidence"]
|
| 36 |
+
probabilities = model_prediction.get("probabilities", {})
|
| 37 |
+
biomarkers = state["patient_biomarkers"]
|
| 38 |
|
| 39 |
# Collect previous agent findings
|
| 40 |
+
biomarker_analysis = state.get("biomarker_analysis") or {}
|
| 41 |
disease_explanation = self._get_agent_findings(state, "Disease Explainer")
|
| 42 |
linker_findings = self._get_agent_findings(state, "Biomarker-Disease Linker")
|
| 43 |
|
| 44 |
print(f"\nAssessing confidence for {disease} prediction...")
|
| 45 |
|
| 46 |
# Evaluate evidence strength
|
| 47 |
+
evidence_strength = self._evaluate_evidence_strength(biomarker_analysis, disease_explanation, linker_findings)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
# Identify limitations
|
| 50 |
+
limitations = self._identify_limitations(biomarkers, biomarker_analysis, probabilities)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Calculate aggregate reliability
|
| 53 |
+
reliability = self._calculate_reliability(ml_confidence, evidence_strength, len(limitations))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
# Generate assessment summary
|
| 56 |
assessment_summary = self._generate_assessment(
|
| 57 |
+
disease, ml_confidence, reliability, evidence_strength, limitations
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
# Create agent output
|
|
|
|
| 67 |
"limitations": limitations,
|
| 68 |
"assessment_summary": assessment_summary,
|
| 69 |
"recommendation": self._get_recommendation(reliability),
|
| 70 |
+
"alternative_diagnoses": self._get_alternatives(probabilities),
|
| 71 |
+
},
|
| 72 |
)
|
| 73 |
|
| 74 |
# Update state
|
|
|
|
| 77 |
print(f" - Evidence strength: {evidence_strength}")
|
| 78 |
print(f" - Limitations identified: {len(limitations)}")
|
| 79 |
|
| 80 |
+
return {"agent_outputs": [output]}
|
| 81 |
|
| 82 |
def _get_agent_findings(self, state: GuildState, agent_name: str) -> dict:
|
| 83 |
"""Extract findings from a specific agent"""
|
| 84 |
+
for output in state.get("agent_outputs", []):
|
| 85 |
if output.agent_name == agent_name:
|
| 86 |
return output.findings
|
| 87 |
return {}
|
| 88 |
|
| 89 |
def _evaluate_evidence_strength(
|
| 90 |
+
self, biomarker_analysis: dict, disease_explanation: dict, linker_findings: dict
|
|
|
|
|
|
|
|
|
|
| 91 |
) -> str:
|
| 92 |
"""Evaluate the strength of supporting evidence"""
|
| 93 |
|
|
|
|
| 95 |
max_score = 5
|
| 96 |
|
| 97 |
# Check biomarker validation quality
|
| 98 |
+
flags = biomarker_analysis.get("biomarker_flags", [])
|
| 99 |
+
abnormal_count = len([f for f in flags if f.get("status") != "NORMAL"])
|
| 100 |
if abnormal_count >= 3:
|
| 101 |
score += 1
|
| 102 |
if abnormal_count >= 5:
|
| 103 |
score += 1
|
| 104 |
|
| 105 |
# Check disease explanation quality
|
| 106 |
+
if disease_explanation.get("retrieval_quality", 0) >= 3:
|
| 107 |
score += 1
|
| 108 |
|
| 109 |
# Check biomarker-disease linking
|
| 110 |
+
key_drivers = linker_findings.get("key_drivers", [])
|
| 111 |
if len(key_drivers) >= 2:
|
| 112 |
score += 1
|
| 113 |
if len(key_drivers) >= 4:
|
|
|
|
| 122 |
return "WEAK"
|
| 123 |
|
| 124 |
def _identify_limitations(
|
| 125 |
+
self, biomarkers: dict[str, float], biomarker_analysis: dict, probabilities: dict[str, float]
|
|
|
|
|
|
|
|
|
|
| 126 |
) -> list[str]:
|
| 127 |
"""Identify limitations and uncertainties"""
|
| 128 |
limitations = []
|
|
|
|
| 139 |
top1, prob1 = sorted_probs[0]
|
| 140 |
top2, prob2 = sorted_probs[1]
|
| 141 |
if prob2 > 0.15: # Alternative is significant
|
| 142 |
+
limitations.append(f"Differential diagnosis: {top2} also possible ({prob2:.1%} probability)")
|
|
|
|
|
|
|
| 143 |
|
| 144 |
# Check for normal biomarkers despite prediction
|
| 145 |
+
flags = biomarker_analysis.get("biomarker_flags", [])
|
| 146 |
+
relevant = biomarker_analysis.get("relevant_biomarkers", [])
|
| 147 |
+
normal_relevant = [f for f in flags if f.get("name") in relevant and f.get("status") == "NORMAL"]
|
|
|
|
|
|
|
|
|
|
| 148 |
if len(normal_relevant) >= 2:
|
| 149 |
+
limitations.append("Some disease-relevant biomarkers are within normal range")
|
|
|
|
|
|
|
| 150 |
|
| 151 |
# Check for safety alerts (indicates complexity)
|
| 152 |
+
alerts = biomarker_analysis.get("safety_alerts", [])
|
| 153 |
if len(alerts) >= 2:
|
| 154 |
+
limitations.append("Multiple critical values detected; professional evaluation essential")
|
|
|
|
|
|
|
| 155 |
|
| 156 |
return limitations
|
| 157 |
|
| 158 |
+
def _calculate_reliability(self, ml_confidence: float, evidence_strength: str, limitation_count: int) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
"""Calculate overall prediction reliability"""
|
| 160 |
|
| 161 |
score = 0
|
|
|
|
| 188 |
return "LOW"
|
| 189 |
|
| 190 |
def _generate_assessment(
|
| 191 |
+
self, disease: str, ml_confidence: float, reliability: str, evidence_strength: str, limitations: list[str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
) -> str:
|
| 193 |
"""Generate human-readable assessment summary"""
|
| 194 |
|
|
|
|
| 230 |
alternatives = []
|
| 231 |
for disease, prob in sorted_probs[1:4]: # Top 3 alternatives
|
| 232 |
if prob > 0.05: # Only significant alternatives
|
| 233 |
+
alternatives.append(
|
| 234 |
+
{"disease": disease, "probability": prob, "note": "Consider discussing with healthcare provider"}
|
| 235 |
+
)
|
|
|
|
|
|
|
| 236 |
|
| 237 |
return alternatives
|
| 238 |
|
src/agents/disease_explainer.py
CHANGED
|
@@ -17,7 +17,7 @@ class DiseaseExplainerAgent:
|
|
| 17 |
def __init__(self, retriever):
|
| 18 |
"""
|
| 19 |
Initialize with a retriever for medical PDFs.
|
| 20 |
-
|
| 21 |
Args:
|
| 22 |
retriever: Vector store retriever for disease documents
|
| 23 |
"""
|
|
@@ -27,25 +27,25 @@ class DiseaseExplainerAgent:
|
|
| 27 |
def explain(self, state: GuildState) -> GuildState:
|
| 28 |
"""
|
| 29 |
Retrieve and explain disease pathophysiology.
|
| 30 |
-
|
| 31 |
Args:
|
| 32 |
state: Current guild state
|
| 33 |
-
|
| 34 |
Returns:
|
| 35 |
Updated state with disease explanation
|
| 36 |
"""
|
| 37 |
-
print("\n" + "="*70)
|
| 38 |
print("EXECUTING: Disease Explainer Agent (RAG)")
|
| 39 |
-
print("="*70)
|
| 40 |
|
| 41 |
-
model_prediction = state[
|
| 42 |
-
disease = model_prediction[
|
| 43 |
-
confidence = model_prediction[
|
| 44 |
|
| 45 |
# Configure retrieval based on SOP β create a copy to avoid mutating shared retriever
|
| 46 |
-
retrieval_k = state[
|
| 47 |
original_search_kwargs = dict(self.retriever.search_kwargs)
|
| 48 |
-
self.retriever.search_kwargs = {**original_search_kwargs,
|
| 49 |
|
| 50 |
# Retrieve relevant documents
|
| 51 |
print(f"\nRetrieving information about: {disease}")
|
|
@@ -62,33 +62,33 @@ class DiseaseExplainerAgent:
|
|
| 62 |
|
| 63 |
print(f"Retrieved {len(docs)} relevant document chunks")
|
| 64 |
|
| 65 |
-
if state[
|
| 66 |
explanation = {
|
| 67 |
"pathophysiology": "Insufficient evidence available in the knowledge base to explain this condition.",
|
| 68 |
"diagnostic_criteria": "Insufficient evidence available to list diagnostic criteria.",
|
| 69 |
"clinical_presentation": "Insufficient evidence available to describe clinical presentation.",
|
| 70 |
-
"summary": "Insufficient evidence available for a detailed explanation."
|
| 71 |
}
|
| 72 |
citations = []
|
| 73 |
output = AgentOutput(
|
| 74 |
agent_name="Disease Explainer",
|
| 75 |
findings={
|
| 76 |
"disease": disease,
|
| 77 |
-
"pathophysiology": explanation[
|
| 78 |
-
"diagnostic_criteria": explanation[
|
| 79 |
-
"clinical_presentation": explanation[
|
| 80 |
-
"mechanism_summary": explanation[
|
| 81 |
"citations": citations,
|
| 82 |
"confidence": confidence,
|
| 83 |
"retrieval_quality": 0,
|
| 84 |
-
"citations_missing": True
|
| 85 |
-
}
|
| 86 |
)
|
| 87 |
|
| 88 |
print("\nDisease explanation generated")
|
| 89 |
print(" - Pathophysiology: insufficient evidence")
|
| 90 |
print(" - Citations: 0 sources")
|
| 91 |
-
return {
|
| 92 |
|
| 93 |
# Generate explanation
|
| 94 |
explanation = self._generate_explanation(disease, docs, confidence)
|
|
@@ -101,15 +101,15 @@ class DiseaseExplainerAgent:
|
|
| 101 |
agent_name="Disease Explainer",
|
| 102 |
findings={
|
| 103 |
"disease": disease,
|
| 104 |
-
"pathophysiology": explanation[
|
| 105 |
-
"diagnostic_criteria": explanation[
|
| 106 |
-
"clinical_presentation": explanation[
|
| 107 |
-
"mechanism_summary": explanation[
|
| 108 |
"citations": citations,
|
| 109 |
"confidence": confidence,
|
| 110 |
"retrieval_quality": len(docs),
|
| 111 |
-
"citations_missing": False
|
| 112 |
-
}
|
| 113 |
)
|
| 114 |
|
| 115 |
# Update state
|
|
@@ -117,19 +117,21 @@ class DiseaseExplainerAgent:
|
|
| 117 |
print(f" - Pathophysiology: {len(explanation['pathophysiology'])} chars")
|
| 118 |
print(f" - Citations: {len(citations)} sources")
|
| 119 |
|
| 120 |
-
return {
|
| 121 |
|
| 122 |
def _generate_explanation(self, disease: str, docs: list, confidence: float) -> dict:
|
| 123 |
"""Generate structured disease explanation using LLM and retrieved docs"""
|
| 124 |
|
| 125 |
# Format retrieved context
|
| 126 |
-
context = "\n\n---\n\n".join(
|
| 127 |
-
f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
|
| 128 |
-
|
| 129 |
-
])
|
| 130 |
|
| 131 |
-
prompt = ChatPromptTemplate.from_messages(
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
| 133 |
Based on the provided medical literature, explain the disease in clear, accessible language.
|
| 134 |
Structure your response with these sections:
|
| 135 |
1. PATHOPHYSIOLOGY: The underlying biological mechanisms
|
|
@@ -137,24 +139,25 @@ class DiseaseExplainerAgent:
|
|
| 137 |
3. CLINICAL_PRESENTATION: Common symptoms and signs
|
| 138 |
4. SUMMARY: A 2-3 sentence overview
|
| 139 |
|
| 140 |
-
Be accurate, cite-able, and patient-friendly. Focus on how the disease affects blood biomarkers."""
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
| 142 |
Prediction Confidence: {confidence:.1%}
|
| 143 |
|
| 144 |
Medical Literature Context:
|
| 145 |
{context}
|
| 146 |
|
| 147 |
-
Please provide a structured explanation."""
|
| 148 |
-
|
|
|
|
|
|
|
| 149 |
|
| 150 |
chain = prompt | self.llm
|
| 151 |
|
| 152 |
try:
|
| 153 |
-
response = chain.invoke({
|
| 154 |
-
"disease": disease,
|
| 155 |
-
"confidence": confidence,
|
| 156 |
-
"context": context
|
| 157 |
-
})
|
| 158 |
|
| 159 |
# Parse structured response
|
| 160 |
content = response.content
|
|
@@ -166,41 +169,36 @@ class DiseaseExplainerAgent:
|
|
| 166 |
"pathophysiology": f"{disease} is a medical condition requiring professional diagnosis.",
|
| 167 |
"diagnostic_criteria": "Consult medical guidelines for diagnostic criteria.",
|
| 168 |
"clinical_presentation": "Clinical presentation varies by individual.",
|
| 169 |
-
"summary": f"{disease} detected with {confidence:.1%} confidence. Consult healthcare provider."
|
| 170 |
}
|
| 171 |
|
| 172 |
return explanation
|
| 173 |
|
| 174 |
def _parse_explanation(self, content: str) -> dict:
|
| 175 |
"""Parse LLM response into structured sections"""
|
| 176 |
-
sections = {
|
| 177 |
-
"pathophysiology": "",
|
| 178 |
-
"diagnostic_criteria": "",
|
| 179 |
-
"clinical_presentation": "",
|
| 180 |
-
"summary": ""
|
| 181 |
-
}
|
| 182 |
|
| 183 |
# Simple parsing logic
|
| 184 |
current_section = None
|
| 185 |
-
lines = content.split(
|
| 186 |
|
| 187 |
for line in lines:
|
| 188 |
line_upper = line.upper().strip()
|
| 189 |
|
| 190 |
-
if
|
| 191 |
-
current_section =
|
| 192 |
-
elif
|
| 193 |
-
current_section =
|
| 194 |
-
elif
|
| 195 |
-
current_section =
|
| 196 |
-
elif
|
| 197 |
-
current_section =
|
| 198 |
elif current_section and line.strip():
|
| 199 |
sections[current_section] += line + "\n"
|
| 200 |
|
| 201 |
# If parsing failed, use full content as summary
|
| 202 |
if not any(sections.values()):
|
| 203 |
-
sections[
|
| 204 |
|
| 205 |
return sections
|
| 206 |
|
|
@@ -209,15 +207,15 @@ class DiseaseExplainerAgent:
|
|
| 209 |
citations = []
|
| 210 |
|
| 211 |
for doc in docs:
|
| 212 |
-
source = doc.metadata.get(
|
| 213 |
-
page = doc.metadata.get(
|
| 214 |
|
| 215 |
# Clean up source path
|
| 216 |
-
if
|
| 217 |
source = Path(source).name
|
| 218 |
|
| 219 |
citation = f"{source}"
|
| 220 |
-
if page !=
|
| 221 |
citation += f" (Page {page})"
|
| 222 |
|
| 223 |
citations.append(citation)
|
|
|
|
| 17 |
def __init__(self, retriever):
|
| 18 |
"""
|
| 19 |
Initialize with a retriever for medical PDFs.
|
| 20 |
+
|
| 21 |
Args:
|
| 22 |
retriever: Vector store retriever for disease documents
|
| 23 |
"""
|
|
|
|
| 27 |
def explain(self, state: GuildState) -> GuildState:
|
| 28 |
"""
|
| 29 |
Retrieve and explain disease pathophysiology.
|
| 30 |
+
|
| 31 |
Args:
|
| 32 |
state: Current guild state
|
| 33 |
+
|
| 34 |
Returns:
|
| 35 |
Updated state with disease explanation
|
| 36 |
"""
|
| 37 |
+
print("\n" + "=" * 70)
|
| 38 |
print("EXECUTING: Disease Explainer Agent (RAG)")
|
| 39 |
+
print("=" * 70)
|
| 40 |
|
| 41 |
+
model_prediction = state["model_prediction"]
|
| 42 |
+
disease = model_prediction["disease"]
|
| 43 |
+
confidence = model_prediction["confidence"]
|
| 44 |
|
| 45 |
# Configure retrieval based on SOP β create a copy to avoid mutating shared retriever
|
| 46 |
+
retrieval_k = state["sop"].disease_explainer_k
|
| 47 |
original_search_kwargs = dict(self.retriever.search_kwargs)
|
| 48 |
+
self.retriever.search_kwargs = {**original_search_kwargs, "k": retrieval_k}
|
| 49 |
|
| 50 |
# Retrieve relevant documents
|
| 51 |
print(f"\nRetrieving information about: {disease}")
|
|
|
|
| 62 |
|
| 63 |
print(f"Retrieved {len(docs)} relevant document chunks")
|
| 64 |
|
| 65 |
+
if state["sop"].require_pdf_citations and not docs:
|
| 66 |
explanation = {
|
| 67 |
"pathophysiology": "Insufficient evidence available in the knowledge base to explain this condition.",
|
| 68 |
"diagnostic_criteria": "Insufficient evidence available to list diagnostic criteria.",
|
| 69 |
"clinical_presentation": "Insufficient evidence available to describe clinical presentation.",
|
| 70 |
+
"summary": "Insufficient evidence available for a detailed explanation.",
|
| 71 |
}
|
| 72 |
citations = []
|
| 73 |
output = AgentOutput(
|
| 74 |
agent_name="Disease Explainer",
|
| 75 |
findings={
|
| 76 |
"disease": disease,
|
| 77 |
+
"pathophysiology": explanation["pathophysiology"],
|
| 78 |
+
"diagnostic_criteria": explanation["diagnostic_criteria"],
|
| 79 |
+
"clinical_presentation": explanation["clinical_presentation"],
|
| 80 |
+
"mechanism_summary": explanation["summary"],
|
| 81 |
"citations": citations,
|
| 82 |
"confidence": confidence,
|
| 83 |
"retrieval_quality": 0,
|
| 84 |
+
"citations_missing": True,
|
| 85 |
+
},
|
| 86 |
)
|
| 87 |
|
| 88 |
print("\nDisease explanation generated")
|
| 89 |
print(" - Pathophysiology: insufficient evidence")
|
| 90 |
print(" - Citations: 0 sources")
|
| 91 |
+
return {"agent_outputs": [output]}
|
| 92 |
|
| 93 |
# Generate explanation
|
| 94 |
explanation = self._generate_explanation(disease, docs, confidence)
|
|
|
|
| 101 |
agent_name="Disease Explainer",
|
| 102 |
findings={
|
| 103 |
"disease": disease,
|
| 104 |
+
"pathophysiology": explanation["pathophysiology"],
|
| 105 |
+
"diagnostic_criteria": explanation["diagnostic_criteria"],
|
| 106 |
+
"clinical_presentation": explanation["clinical_presentation"],
|
| 107 |
+
"mechanism_summary": explanation["summary"],
|
| 108 |
"citations": citations,
|
| 109 |
"confidence": confidence,
|
| 110 |
"retrieval_quality": len(docs),
|
| 111 |
+
"citations_missing": False,
|
| 112 |
+
},
|
| 113 |
)
|
| 114 |
|
| 115 |
# Update state
|
|
|
|
| 117 |
print(f" - Pathophysiology: {len(explanation['pathophysiology'])} chars")
|
| 118 |
print(f" - Citations: {len(citations)} sources")
|
| 119 |
|
| 120 |
+
return {"agent_outputs": [output]}
|
| 121 |
|
| 122 |
def _generate_explanation(self, disease: str, docs: list, confidence: float) -> dict:
|
| 123 |
"""Generate structured disease explanation using LLM and retrieved docs"""
|
| 124 |
|
| 125 |
# Format retrieved context
|
| 126 |
+
context = "\n\n---\n\n".join(
|
| 127 |
+
[f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}" for doc in docs]
|
| 128 |
+
)
|
|
|
|
| 129 |
|
| 130 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 131 |
+
[
|
| 132 |
+
(
|
| 133 |
+
"system",
|
| 134 |
+
"""You are a medical expert explaining diseases for patient self-assessment.
|
| 135 |
Based on the provided medical literature, explain the disease in clear, accessible language.
|
| 136 |
Structure your response with these sections:
|
| 137 |
1. PATHOPHYSIOLOGY: The underlying biological mechanisms
|
|
|
|
| 139 |
3. CLINICAL_PRESENTATION: Common symptoms and signs
|
| 140 |
4. SUMMARY: A 2-3 sentence overview
|
| 141 |
|
| 142 |
+
Be accurate, cite-able, and patient-friendly. Focus on how the disease affects blood biomarkers.""",
|
| 143 |
+
),
|
| 144 |
+
(
|
| 145 |
+
"human",
|
| 146 |
+
"""Disease: {disease}
|
| 147 |
Prediction Confidence: {confidence:.1%}
|
| 148 |
|
| 149 |
Medical Literature Context:
|
| 150 |
{context}
|
| 151 |
|
| 152 |
+
Please provide a structured explanation.""",
|
| 153 |
+
),
|
| 154 |
+
]
|
| 155 |
+
)
|
| 156 |
|
| 157 |
chain = prompt | self.llm
|
| 158 |
|
| 159 |
try:
|
| 160 |
+
response = chain.invoke({"disease": disease, "confidence": confidence, "context": context})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# Parse structured response
|
| 163 |
content = response.content
|
|
|
|
| 169 |
"pathophysiology": f"{disease} is a medical condition requiring professional diagnosis.",
|
| 170 |
"diagnostic_criteria": "Consult medical guidelines for diagnostic criteria.",
|
| 171 |
"clinical_presentation": "Clinical presentation varies by individual.",
|
| 172 |
+
"summary": f"{disease} detected with {confidence:.1%} confidence. Consult healthcare provider.",
|
| 173 |
}
|
| 174 |
|
| 175 |
return explanation
|
| 176 |
|
| 177 |
def _parse_explanation(self, content: str) -> dict:
|
| 178 |
"""Parse LLM response into structured sections"""
|
| 179 |
+
sections = {"pathophysiology": "", "diagnostic_criteria": "", "clinical_presentation": "", "summary": ""}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# Simple parsing logic
|
| 182 |
current_section = None
|
| 183 |
+
lines = content.split("\n")
|
| 184 |
|
| 185 |
for line in lines:
|
| 186 |
line_upper = line.upper().strip()
|
| 187 |
|
| 188 |
+
if "PATHOPHYSIOLOGY" in line_upper:
|
| 189 |
+
current_section = "pathophysiology"
|
| 190 |
+
elif "DIAGNOSTIC" in line_upper:
|
| 191 |
+
current_section = "diagnostic_criteria"
|
| 192 |
+
elif "CLINICAL" in line_upper or "PRESENTATION" in line_upper:
|
| 193 |
+
current_section = "clinical_presentation"
|
| 194 |
+
elif "SUMMARY" in line_upper:
|
| 195 |
+
current_section = "summary"
|
| 196 |
elif current_section and line.strip():
|
| 197 |
sections[current_section] += line + "\n"
|
| 198 |
|
| 199 |
# If parsing failed, use full content as summary
|
| 200 |
if not any(sections.values()):
|
| 201 |
+
sections["summary"] = content[:500]
|
| 202 |
|
| 203 |
return sections
|
| 204 |
|
|
|
|
| 207 |
citations = []
|
| 208 |
|
| 209 |
for doc in docs:
|
| 210 |
+
source = doc.metadata.get("source", "Unknown")
|
| 211 |
+
page = doc.metadata.get("page", "N/A")
|
| 212 |
|
| 213 |
# Clean up source path
|
| 214 |
+
if "\\" in source or "/" in source:
|
| 215 |
source = Path(source).name
|
| 216 |
|
| 217 |
citation = f"{source}"
|
| 218 |
+
if page != "N/A":
|
| 219 |
citation += f" (Page {page})"
|
| 220 |
|
| 221 |
citations.append(citation)
|
src/agents/response_synthesizer.py
CHANGED
|
@@ -20,21 +20,21 @@ class ResponseSynthesizerAgent:
|
|
| 20 |
def synthesize(self, state: GuildState) -> GuildState:
|
| 21 |
"""
|
| 22 |
Synthesize all agent outputs into final response.
|
| 23 |
-
|
| 24 |
Args:
|
| 25 |
state: Complete guild state with all agent outputs
|
| 26 |
-
|
| 27 |
Returns:
|
| 28 |
Updated state with final_response
|
| 29 |
"""
|
| 30 |
-
print("\n" + "="*70)
|
| 31 |
print("EXECUTING: Response Synthesizer Agent")
|
| 32 |
-
print("="*70)
|
| 33 |
|
| 34 |
-
model_prediction = state[
|
| 35 |
-
patient_biomarkers = state[
|
| 36 |
-
patient_context = state.get(
|
| 37 |
-
agent_outputs = state.get(
|
| 38 |
|
| 39 |
# Collect findings from all agents
|
| 40 |
findings = self._collect_findings(agent_outputs)
|
|
@@ -62,24 +62,24 @@ class ResponseSynthesizerAgent:
|
|
| 62 |
"disease_explanation": self._build_disease_explanation(findings),
|
| 63 |
"recommendations": recs,
|
| 64 |
"confidence_assessment": self._build_confidence_assessment(findings),
|
| 65 |
-
"alternative_diagnoses": self._build_alternative_diagnoses(findings)
|
| 66 |
-
}
|
| 67 |
}
|
| 68 |
|
| 69 |
# Generate patient-friendly summary
|
| 70 |
response["patient_summary"]["narrative"] = self._generate_narrative_summary(
|
| 71 |
-
model_prediction,
|
| 72 |
-
findings,
|
| 73 |
-
response
|
| 74 |
)
|
| 75 |
|
| 76 |
print("\nResponse synthesis complete")
|
| 77 |
print(" - Patient summary: Generated")
|
| 78 |
print(f" - Prediction explanation: {len(response['prediction_explanation']['key_drivers'])} key drivers")
|
| 79 |
-
print(
|
|
|
|
|
|
|
| 80 |
print(f" - Safety alerts: {len(response['safety_alerts'])} alerts")
|
| 81 |
|
| 82 |
-
return {
|
| 83 |
|
| 84 |
def _collect_findings(self, agent_outputs: list) -> dict[str, Any]:
|
| 85 |
"""Organize all agent findings by agent name"""
|
|
@@ -91,19 +91,19 @@ class ResponseSynthesizerAgent:
|
|
| 91 |
def _build_patient_summary(self, biomarkers: dict, findings: dict) -> dict:
|
| 92 |
"""Build patient summary section"""
|
| 93 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 94 |
-
flags = biomarker_analysis.get(
|
| 95 |
|
| 96 |
# Count biomarker statuses
|
| 97 |
-
critical = len([f for f in flags if
|
| 98 |
-
abnormal = len([f for f in flags if f.get(
|
| 99 |
|
| 100 |
return {
|
| 101 |
"total_biomarkers_tested": len(biomarkers),
|
| 102 |
"biomarkers_in_normal_range": len(flags) - abnormal,
|
| 103 |
"biomarkers_out_of_range": abnormal,
|
| 104 |
"critical_values": critical,
|
| 105 |
-
"overall_risk_profile": biomarker_analysis.get(
|
| 106 |
-
"narrative": "" # Will be filled later
|
| 107 |
}
|
| 108 |
|
| 109 |
def _build_prediction_explanation(self, model_prediction: dict, findings: dict) -> dict:
|
|
@@ -111,18 +111,18 @@ class ResponseSynthesizerAgent:
|
|
| 111 |
disease_explanation = findings.get("Disease Explainer", {})
|
| 112 |
linker_findings = findings.get("Biomarker-Disease Linker", {})
|
| 113 |
|
| 114 |
-
disease = model_prediction[
|
| 115 |
-
confidence = model_prediction[
|
| 116 |
|
| 117 |
# Get key drivers
|
| 118 |
-
key_drivers_raw = linker_findings.get(
|
| 119 |
key_drivers = [
|
| 120 |
{
|
| 121 |
-
"biomarker": kd.get(
|
| 122 |
-
"value": kd.get(
|
| 123 |
-
"contribution": kd.get(
|
| 124 |
-
"explanation": kd.get(
|
| 125 |
-
"evidence": kd.get(
|
| 126 |
}
|
| 127 |
for kd in key_drivers_raw
|
| 128 |
]
|
|
@@ -131,25 +131,25 @@ class ResponseSynthesizerAgent:
|
|
| 131 |
"primary_disease": disease,
|
| 132 |
"confidence": confidence,
|
| 133 |
"key_drivers": key_drivers,
|
| 134 |
-
"mechanism_summary": disease_explanation.get(
|
| 135 |
-
"pathophysiology": disease_explanation.get(
|
| 136 |
-
"pdf_references": disease_explanation.get(
|
| 137 |
}
|
| 138 |
|
| 139 |
def _build_biomarker_flags(self, findings: dict) -> list[dict]:
|
| 140 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 141 |
-
return biomarker_analysis.get(
|
| 142 |
|
| 143 |
def _build_key_drivers(self, findings: dict) -> list[dict]:
|
| 144 |
linker_findings = findings.get("Biomarker-Disease Linker", {})
|
| 145 |
-
return linker_findings.get(
|
| 146 |
|
| 147 |
def _build_disease_explanation(self, findings: dict) -> dict:
|
| 148 |
disease_explanation = findings.get("Disease Explainer", {})
|
| 149 |
return {
|
| 150 |
-
"pathophysiology": disease_explanation.get(
|
| 151 |
-
"citations": disease_explanation.get(
|
| 152 |
-
"retrieved_chunks": disease_explanation.get(
|
| 153 |
}
|
| 154 |
|
| 155 |
def _build_recommendations(self, findings: dict) -> dict:
|
|
@@ -157,10 +157,10 @@ class ResponseSynthesizerAgent:
|
|
| 157 |
guidelines = findings.get("Clinical Guidelines", {})
|
| 158 |
|
| 159 |
return {
|
| 160 |
-
"immediate_actions": guidelines.get(
|
| 161 |
-
"lifestyle_changes": guidelines.get(
|
| 162 |
-
"monitoring": guidelines.get(
|
| 163 |
-
"guideline_citations": guidelines.get(
|
| 164 |
}
|
| 165 |
|
| 166 |
def _build_confidence_assessment(self, findings: dict) -> dict:
|
|
@@ -168,22 +168,22 @@ class ResponseSynthesizerAgent:
|
|
| 168 |
assessment = findings.get("Confidence Assessor", {})
|
| 169 |
|
| 170 |
return {
|
| 171 |
-
"prediction_reliability": assessment.get(
|
| 172 |
-
"evidence_strength": assessment.get(
|
| 173 |
-
"limitations": assessment.get(
|
| 174 |
-
"recommendation": assessment.get(
|
| 175 |
-
"assessment_summary": assessment.get(
|
| 176 |
-
"alternative_diagnoses": assessment.get(
|
| 177 |
}
|
| 178 |
|
| 179 |
def _build_alternative_diagnoses(self, findings: dict) -> list[dict]:
|
| 180 |
assessment = findings.get("Confidence Assessor", {})
|
| 181 |
-
return assessment.get(
|
| 182 |
|
| 183 |
def _build_safety_alerts(self, findings: dict) -> list[dict]:
|
| 184 |
"""Build safety alerts section"""
|
| 185 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 186 |
-
return biomarker_analysis.get(
|
| 187 |
|
| 188 |
def _build_metadata(self, state: GuildState) -> dict:
|
| 189 |
"""Build metadata section"""
|
|
@@ -193,59 +193,64 @@ class ResponseSynthesizerAgent:
|
|
| 193 |
"timestamp": datetime.now().isoformat(),
|
| 194 |
"system_version": "MediGuard AI RAG-Helper v1.0",
|
| 195 |
"sop_version": "Baseline",
|
| 196 |
-
"agents_executed": [output.agent_name for output in state.get(
|
| 197 |
-
"disclaimer": "This is an AI-assisted analysis tool for patient self-assessment. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always consult qualified healthcare providers for medical decisions."
|
| 198 |
}
|
| 199 |
|
| 200 |
-
def _generate_narrative_summary(
|
| 201 |
-
self,
|
| 202 |
-
model_prediction,
|
| 203 |
-
findings: dict,
|
| 204 |
-
response: dict
|
| 205 |
-
) -> str:
|
| 206 |
"""Generate a patient-friendly narrative summary using LLM"""
|
| 207 |
|
| 208 |
-
disease = model_prediction[
|
| 209 |
-
confidence = model_prediction[
|
| 210 |
-
reliability = response[
|
| 211 |
|
| 212 |
# Get key points
|
| 213 |
-
critical_count = response[
|
| 214 |
-
abnormal_count = response[
|
| 215 |
-
key_drivers = response[
|
| 216 |
-
|
| 217 |
-
prompt = ChatPromptTemplate.from_messages(
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
| 219 |
Write a clear, compassionate 3-4 sentence summary that:
|
| 220 |
1. States the predicted condition and confidence level
|
| 221 |
2. Highlights the most important biomarker findings
|
| 222 |
3. Emphasizes the need for medical consultation
|
| 223 |
4. Offers reassurance while being honest about findings
|
| 224 |
|
| 225 |
-
Use patient-friendly language. Avoid medical jargon. Be supportive and clear."""
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
| 227 |
Model Confidence: {confidence:.1%}
|
| 228 |
Overall Reliability: {reliability}
|
| 229 |
Critical Values: {critical}
|
| 230 |
Out-of-Range Values: {abnormal}
|
| 231 |
Top Biomarker Drivers: {drivers}
|
| 232 |
|
| 233 |
-
Write a compassionate patient summary."""
|
| 234 |
-
|
|
|
|
|
|
|
| 235 |
|
| 236 |
chain = prompt | self.llm
|
| 237 |
|
| 238 |
try:
|
| 239 |
-
driver_names = [kd[
|
| 240 |
-
|
| 241 |
-
response_obj = chain.invoke(
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
|
|
|
|
|
|
| 249 |
|
| 250 |
return response_obj.content.strip()
|
| 251 |
|
|
|
|
| 20 |
def synthesize(self, state: GuildState) -> GuildState:
|
| 21 |
"""
|
| 22 |
Synthesize all agent outputs into final response.
|
| 23 |
+
|
| 24 |
Args:
|
| 25 |
state: Complete guild state with all agent outputs
|
| 26 |
+
|
| 27 |
Returns:
|
| 28 |
Updated state with final_response
|
| 29 |
"""
|
| 30 |
+
print("\n" + "=" * 70)
|
| 31 |
print("EXECUTING: Response Synthesizer Agent")
|
| 32 |
+
print("=" * 70)
|
| 33 |
|
| 34 |
+
model_prediction = state["model_prediction"]
|
| 35 |
+
patient_biomarkers = state["patient_biomarkers"]
|
| 36 |
+
patient_context = state.get("patient_context", {})
|
| 37 |
+
agent_outputs = state.get("agent_outputs", [])
|
| 38 |
|
| 39 |
# Collect findings from all agents
|
| 40 |
findings = self._collect_findings(agent_outputs)
|
|
|
|
| 62 |
"disease_explanation": self._build_disease_explanation(findings),
|
| 63 |
"recommendations": recs,
|
| 64 |
"confidence_assessment": self._build_confidence_assessment(findings),
|
| 65 |
+
"alternative_diagnoses": self._build_alternative_diagnoses(findings),
|
| 66 |
+
},
|
| 67 |
}
|
| 68 |
|
| 69 |
# Generate patient-friendly summary
|
| 70 |
response["patient_summary"]["narrative"] = self._generate_narrative_summary(
|
| 71 |
+
model_prediction, findings, response
|
|
|
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
print("\nResponse synthesis complete")
|
| 75 |
print(" - Patient summary: Generated")
|
| 76 |
print(f" - Prediction explanation: {len(response['prediction_explanation']['key_drivers'])} key drivers")
|
| 77 |
+
print(
|
| 78 |
+
f" - Recommendations: {len(response['clinical_recommendations']['immediate_actions'])} immediate actions"
|
| 79 |
+
)
|
| 80 |
print(f" - Safety alerts: {len(response['safety_alerts'])} alerts")
|
| 81 |
|
| 82 |
+
return {"final_response": response}
|
| 83 |
|
| 84 |
def _collect_findings(self, agent_outputs: list) -> dict[str, Any]:
|
| 85 |
"""Organize all agent findings by agent name"""
|
|
|
|
| 91 |
def _build_patient_summary(self, biomarkers: dict, findings: dict) -> dict:
|
| 92 |
"""Build patient summary section"""
|
| 93 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 94 |
+
flags = biomarker_analysis.get("biomarker_flags", [])
|
| 95 |
|
| 96 |
# Count biomarker statuses
|
| 97 |
+
critical = len([f for f in flags if "CRITICAL" in f.get("status", "")])
|
| 98 |
+
abnormal = len([f for f in flags if f.get("status") != "NORMAL"])
|
| 99 |
|
| 100 |
return {
|
| 101 |
"total_biomarkers_tested": len(biomarkers),
|
| 102 |
"biomarkers_in_normal_range": len(flags) - abnormal,
|
| 103 |
"biomarkers_out_of_range": abnormal,
|
| 104 |
"critical_values": critical,
|
| 105 |
+
"overall_risk_profile": biomarker_analysis.get("summary", "Assessment complete"),
|
| 106 |
+
"narrative": "", # Will be filled later
|
| 107 |
}
|
| 108 |
|
| 109 |
def _build_prediction_explanation(self, model_prediction: dict, findings: dict) -> dict:
|
|
|
|
| 111 |
disease_explanation = findings.get("Disease Explainer", {})
|
| 112 |
linker_findings = findings.get("Biomarker-Disease Linker", {})
|
| 113 |
|
| 114 |
+
disease = model_prediction["disease"]
|
| 115 |
+
confidence = model_prediction["confidence"]
|
| 116 |
|
| 117 |
# Get key drivers
|
| 118 |
+
key_drivers_raw = linker_findings.get("key_drivers", [])
|
| 119 |
key_drivers = [
|
| 120 |
{
|
| 121 |
+
"biomarker": kd.get("biomarker"),
|
| 122 |
+
"value": kd.get("value"),
|
| 123 |
+
"contribution": kd.get("contribution"),
|
| 124 |
+
"explanation": kd.get("explanation"),
|
| 125 |
+
"evidence": kd.get("evidence", "")[:200], # Truncate
|
| 126 |
}
|
| 127 |
for kd in key_drivers_raw
|
| 128 |
]
|
|
|
|
| 131 |
"primary_disease": disease,
|
| 132 |
"confidence": confidence,
|
| 133 |
"key_drivers": key_drivers,
|
| 134 |
+
"mechanism_summary": disease_explanation.get("mechanism_summary", disease_explanation.get("summary", "")),
|
| 135 |
+
"pathophysiology": disease_explanation.get("pathophysiology", ""),
|
| 136 |
+
"pdf_references": disease_explanation.get("citations", []),
|
| 137 |
}
|
| 138 |
|
| 139 |
def _build_biomarker_flags(self, findings: dict) -> list[dict]:
|
| 140 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 141 |
+
return biomarker_analysis.get("biomarker_flags", [])
|
| 142 |
|
| 143 |
def _build_key_drivers(self, findings: dict) -> list[dict]:
|
| 144 |
linker_findings = findings.get("Biomarker-Disease Linker", {})
|
| 145 |
+
return linker_findings.get("key_drivers", [])
|
| 146 |
|
| 147 |
def _build_disease_explanation(self, findings: dict) -> dict:
|
| 148 |
disease_explanation = findings.get("Disease Explainer", {})
|
| 149 |
return {
|
| 150 |
+
"pathophysiology": disease_explanation.get("pathophysiology", ""),
|
| 151 |
+
"citations": disease_explanation.get("citations", []),
|
| 152 |
+
"retrieved_chunks": disease_explanation.get("retrieved_chunks"),
|
| 153 |
}
|
| 154 |
|
| 155 |
def _build_recommendations(self, findings: dict) -> dict:
|
|
|
|
| 157 |
guidelines = findings.get("Clinical Guidelines", {})
|
| 158 |
|
| 159 |
return {
|
| 160 |
+
"immediate_actions": guidelines.get("immediate_actions", []),
|
| 161 |
+
"lifestyle_changes": guidelines.get("lifestyle_changes", []),
|
| 162 |
+
"monitoring": guidelines.get("monitoring", []),
|
| 163 |
+
"guideline_citations": guidelines.get("guideline_citations", []),
|
| 164 |
}
|
| 165 |
|
| 166 |
def _build_confidence_assessment(self, findings: dict) -> dict:
|
|
|
|
| 168 |
assessment = findings.get("Confidence Assessor", {})
|
| 169 |
|
| 170 |
return {
|
| 171 |
+
"prediction_reliability": assessment.get("prediction_reliability", "UNKNOWN"),
|
| 172 |
+
"evidence_strength": assessment.get("evidence_strength", "UNKNOWN"),
|
| 173 |
+
"limitations": assessment.get("limitations", []),
|
| 174 |
+
"recommendation": assessment.get("recommendation", "Consult healthcare provider"),
|
| 175 |
+
"assessment_summary": assessment.get("assessment_summary", ""),
|
| 176 |
+
"alternative_diagnoses": assessment.get("alternative_diagnoses", []),
|
| 177 |
}
|
| 178 |
|
| 179 |
def _build_alternative_diagnoses(self, findings: dict) -> list[dict]:
|
| 180 |
assessment = findings.get("Confidence Assessor", {})
|
| 181 |
+
return assessment.get("alternative_diagnoses", [])
|
| 182 |
|
| 183 |
def _build_safety_alerts(self, findings: dict) -> list[dict]:
|
| 184 |
"""Build safety alerts section"""
|
| 185 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 186 |
+
return biomarker_analysis.get("safety_alerts", [])
|
| 187 |
|
| 188 |
def _build_metadata(self, state: GuildState) -> dict:
|
| 189 |
"""Build metadata section"""
|
|
|
|
| 193 |
"timestamp": datetime.now().isoformat(),
|
| 194 |
"system_version": "MediGuard AI RAG-Helper v1.0",
|
| 195 |
"sop_version": "Baseline",
|
| 196 |
+
"agents_executed": [output.agent_name for output in state.get("agent_outputs", [])],
|
| 197 |
+
"disclaimer": "This is an AI-assisted analysis tool for patient self-assessment. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always consult qualified healthcare providers for medical decisions.",
|
| 198 |
}
|
| 199 |
|
| 200 |
+
def _generate_narrative_summary(self, model_prediction, findings: dict, response: dict) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
"""Generate a patient-friendly narrative summary using LLM"""
|
| 202 |
|
| 203 |
+
disease = model_prediction["disease"]
|
| 204 |
+
confidence = model_prediction["confidence"]
|
| 205 |
+
reliability = response["confidence_assessment"]["prediction_reliability"]
|
| 206 |
|
| 207 |
# Get key points
|
| 208 |
+
critical_count = response["patient_summary"]["critical_values"]
|
| 209 |
+
abnormal_count = response["patient_summary"]["biomarkers_out_of_range"]
|
| 210 |
+
key_drivers = response["prediction_explanation"]["key_drivers"]
|
| 211 |
+
|
| 212 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 213 |
+
[
|
| 214 |
+
(
|
| 215 |
+
"system",
|
| 216 |
+
"""You are a medical AI assistant explaining test results to a patient.
|
| 217 |
Write a clear, compassionate 3-4 sentence summary that:
|
| 218 |
1. States the predicted condition and confidence level
|
| 219 |
2. Highlights the most important biomarker findings
|
| 220 |
3. Emphasizes the need for medical consultation
|
| 221 |
4. Offers reassurance while being honest about findings
|
| 222 |
|
| 223 |
+
Use patient-friendly language. Avoid medical jargon. Be supportive and clear.""",
|
| 224 |
+
),
|
| 225 |
+
(
|
| 226 |
+
"human",
|
| 227 |
+
"""Disease Predicted: {disease}
|
| 228 |
Model Confidence: {confidence:.1%}
|
| 229 |
Overall Reliability: {reliability}
|
| 230 |
Critical Values: {critical}
|
| 231 |
Out-of-Range Values: {abnormal}
|
| 232 |
Top Biomarker Drivers: {drivers}
|
| 233 |
|
| 234 |
+
Write a compassionate patient summary.""",
|
| 235 |
+
),
|
| 236 |
+
]
|
| 237 |
+
)
|
| 238 |
|
| 239 |
chain = prompt | self.llm
|
| 240 |
|
| 241 |
try:
|
| 242 |
+
driver_names = [kd["biomarker"] for kd in key_drivers[:3]]
|
| 243 |
+
|
| 244 |
+
response_obj = chain.invoke(
|
| 245 |
+
{
|
| 246 |
+
"disease": disease,
|
| 247 |
+
"confidence": confidence,
|
| 248 |
+
"reliability": reliability,
|
| 249 |
+
"critical": critical_count,
|
| 250 |
+
"abnormal": abnormal_count,
|
| 251 |
+
"drivers": ", ".join(driver_names) if driver_names else "Multiple biomarkers",
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
|
| 255 |
return response_obj.content.strip()
|
| 256 |
|
src/biomarker_normalization.py
CHANGED
|
@@ -3,14 +3,12 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Shared biomarker normalization utilities
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
|
| 7 |
# Normalization map for biomarker aliases to canonical names.
|
| 8 |
NORMALIZATION_MAP: dict[str, str] = {
|
| 9 |
# Glucose variations
|
| 10 |
"glucose": "Glucose",
|
| 11 |
"bloodsugar": "Glucose",
|
| 12 |
"bloodglucose": "Glucose",
|
| 13 |
-
|
| 14 |
# Lipid panel
|
| 15 |
"cholesterol": "Cholesterol",
|
| 16 |
"totalcholesterol": "Cholesterol",
|
|
@@ -20,17 +18,14 @@ NORMALIZATION_MAP: dict[str, str] = {
|
|
| 20 |
"ldlcholesterol": "LDL Cholesterol",
|
| 21 |
"hdl": "HDL Cholesterol",
|
| 22 |
"hdlcholesterol": "HDL Cholesterol",
|
| 23 |
-
|
| 24 |
# Diabetes markers
|
| 25 |
"hba1c": "HbA1c",
|
| 26 |
"a1c": "HbA1c",
|
| 27 |
"hemoglobina1c": "HbA1c",
|
| 28 |
"insulin": "Insulin",
|
| 29 |
-
|
| 30 |
# Body metrics
|
| 31 |
"bmi": "BMI",
|
| 32 |
"bodymassindex": "BMI",
|
| 33 |
-
|
| 34 |
# Complete Blood Count (CBC)
|
| 35 |
"hemoglobin": "Hemoglobin",
|
| 36 |
"hgb": "Hemoglobin",
|
|
@@ -45,14 +40,12 @@ NORMALIZATION_MAP: dict[str, str] = {
|
|
| 45 |
"redcells": "Red Blood Cells",
|
| 46 |
"hematocrit": "Hematocrit",
|
| 47 |
"hct": "Hematocrit",
|
| 48 |
-
|
| 49 |
# Red blood cell indices
|
| 50 |
"mcv": "Mean Corpuscular Volume",
|
| 51 |
"meancorpuscularvolume": "Mean Corpuscular Volume",
|
| 52 |
"mch": "Mean Corpuscular Hemoglobin",
|
| 53 |
"meancorpuscularhemoglobin": "Mean Corpuscular Hemoglobin",
|
| 54 |
"mchc": "Mean Corpuscular Hemoglobin Concentration",
|
| 55 |
-
|
| 56 |
# Cardiovascular
|
| 57 |
"heartrate": "Heart Rate",
|
| 58 |
"hr": "Heart Rate",
|
|
@@ -64,7 +57,6 @@ NORMALIZATION_MAP: dict[str, str] = {
|
|
| 64 |
"diastolic": "Diastolic Blood Pressure",
|
| 65 |
"dbp": "Diastolic Blood Pressure",
|
| 66 |
"troponin": "Troponin",
|
| 67 |
-
|
| 68 |
# Inflammation and liver
|
| 69 |
"creactiveprotein": "C-reactive Protein",
|
| 70 |
"crp": "C-reactive Protein",
|
|
@@ -72,10 +64,8 @@ NORMALIZATION_MAP: dict[str, str] = {
|
|
| 72 |
"alanineaminotransferase": "ALT",
|
| 73 |
"ast": "AST",
|
| 74 |
"aspartateaminotransferase": "AST",
|
| 75 |
-
|
| 76 |
# Kidney
|
| 77 |
"creatinine": "Creatinine",
|
| 78 |
-
|
| 79 |
# Thyroid
|
| 80 |
"tsh": "TSH",
|
| 81 |
"thyroidstimulatinghormone": "TSH",
|
|
@@ -83,7 +73,6 @@ NORMALIZATION_MAP: dict[str, str] = {
|
|
| 83 |
"triiodothyronine": "T3",
|
| 84 |
"t4": "T4",
|
| 85 |
"thyroxine": "T4",
|
| 86 |
-
|
| 87 |
# Electrolytes
|
| 88 |
"sodium": "Sodium",
|
| 89 |
"na": "Sodium",
|
|
@@ -95,14 +84,12 @@ NORMALIZATION_MAP: dict[str, str] = {
|
|
| 95 |
"cl": "Chloride",
|
| 96 |
"bicarbonate": "Bicarbonate",
|
| 97 |
"hco3": "Bicarbonate",
|
| 98 |
-
|
| 99 |
# Kidney / Metabolic
|
| 100 |
"urea": "Urea",
|
| 101 |
"bun": "BUN",
|
| 102 |
"bloodureanitrogen": "BUN",
|
| 103 |
"buncreatinineratio": "BUN_Creatinine_Ratio",
|
| 104 |
"uricacid": "Uric_Acid",
|
| 105 |
-
|
| 106 |
# Liver / Protein
|
| 107 |
"totalprotein": "Total_Protein",
|
| 108 |
"albumin": "Albumin",
|
|
@@ -113,7 +100,6 @@ NORMALIZATION_MAP: dict[str, str] = {
|
|
| 113 |
"bilirubin": "Bilirubin_Total",
|
| 114 |
"alp": "ALP",
|
| 115 |
"alkalinephosphatase": "ALP",
|
| 116 |
-
|
| 117 |
# Lipids
|
| 118 |
"vldl": "VLDL",
|
| 119 |
}
|
|
|
|
| 3 |
Shared biomarker normalization utilities
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 6 |
# Normalization map for biomarker aliases to canonical names.
|
| 7 |
NORMALIZATION_MAP: dict[str, str] = {
|
| 8 |
# Glucose variations
|
| 9 |
"glucose": "Glucose",
|
| 10 |
"bloodsugar": "Glucose",
|
| 11 |
"bloodglucose": "Glucose",
|
|
|
|
| 12 |
# Lipid panel
|
| 13 |
"cholesterol": "Cholesterol",
|
| 14 |
"totalcholesterol": "Cholesterol",
|
|
|
|
| 18 |
"ldlcholesterol": "LDL Cholesterol",
|
| 19 |
"hdl": "HDL Cholesterol",
|
| 20 |
"hdlcholesterol": "HDL Cholesterol",
|
|
|
|
| 21 |
# Diabetes markers
|
| 22 |
"hba1c": "HbA1c",
|
| 23 |
"a1c": "HbA1c",
|
| 24 |
"hemoglobina1c": "HbA1c",
|
| 25 |
"insulin": "Insulin",
|
|
|
|
| 26 |
# Body metrics
|
| 27 |
"bmi": "BMI",
|
| 28 |
"bodymassindex": "BMI",
|
|
|
|
| 29 |
# Complete Blood Count (CBC)
|
| 30 |
"hemoglobin": "Hemoglobin",
|
| 31 |
"hgb": "Hemoglobin",
|
|
|
|
| 40 |
"redcells": "Red Blood Cells",
|
| 41 |
"hematocrit": "Hematocrit",
|
| 42 |
"hct": "Hematocrit",
|
|
|
|
| 43 |
# Red blood cell indices
|
| 44 |
"mcv": "Mean Corpuscular Volume",
|
| 45 |
"meancorpuscularvolume": "Mean Corpuscular Volume",
|
| 46 |
"mch": "Mean Corpuscular Hemoglobin",
|
| 47 |
"meancorpuscularhemoglobin": "Mean Corpuscular Hemoglobin",
|
| 48 |
"mchc": "Mean Corpuscular Hemoglobin Concentration",
|
|
|
|
| 49 |
# Cardiovascular
|
| 50 |
"heartrate": "Heart Rate",
|
| 51 |
"hr": "Heart Rate",
|
|
|
|
| 57 |
"diastolic": "Diastolic Blood Pressure",
|
| 58 |
"dbp": "Diastolic Blood Pressure",
|
| 59 |
"troponin": "Troponin",
|
|
|
|
| 60 |
# Inflammation and liver
|
| 61 |
"creactiveprotein": "C-reactive Protein",
|
| 62 |
"crp": "C-reactive Protein",
|
|
|
|
| 64 |
"alanineaminotransferase": "ALT",
|
| 65 |
"ast": "AST",
|
| 66 |
"aspartateaminotransferase": "AST",
|
|
|
|
| 67 |
# Kidney
|
| 68 |
"creatinine": "Creatinine",
|
|
|
|
| 69 |
# Thyroid
|
| 70 |
"tsh": "TSH",
|
| 71 |
"thyroidstimulatinghormone": "TSH",
|
|
|
|
| 73 |
"triiodothyronine": "T3",
|
| 74 |
"t4": "T4",
|
| 75 |
"thyroxine": "T4",
|
|
|
|
| 76 |
# Electrolytes
|
| 77 |
"sodium": "Sodium",
|
| 78 |
"na": "Sodium",
|
|
|
|
| 84 |
"cl": "Chloride",
|
| 85 |
"bicarbonate": "Bicarbonate",
|
| 86 |
"hco3": "Bicarbonate",
|
|
|
|
| 87 |
# Kidney / Metabolic
|
| 88 |
"urea": "Urea",
|
| 89 |
"bun": "BUN",
|
| 90 |
"bloodureanitrogen": "BUN",
|
| 91 |
"buncreatinineratio": "BUN_Creatinine_Ratio",
|
| 92 |
"uricacid": "Uric_Acid",
|
|
|
|
| 93 |
# Liver / Protein
|
| 94 |
"totalprotein": "Total_Protein",
|
| 95 |
"albumin": "Albumin",
|
|
|
|
| 100 |
"bilirubin": "Bilirubin_Total",
|
| 101 |
"alp": "ALP",
|
| 102 |
"alkalinephosphatase": "ALP",
|
|
|
|
| 103 |
# Lipids
|
| 104 |
"vldl": "VLDL",
|
| 105 |
}
|
src/biomarker_validator.py
CHANGED
|
@@ -16,24 +16,20 @@ class BiomarkerValidator:
|
|
| 16 |
"""Load biomarker reference ranges from JSON file"""
|
| 17 |
ref_path = Path(__file__).parent.parent / reference_file
|
| 18 |
with open(ref_path) as f:
|
| 19 |
-
self.references = json.load(f)[
|
| 20 |
|
| 21 |
def validate_biomarker(
|
| 22 |
-
self,
|
| 23 |
-
name: str,
|
| 24 |
-
value: float,
|
| 25 |
-
gender: str | None = None,
|
| 26 |
-
threshold_pct: float = 0.0
|
| 27 |
) -> BiomarkerFlag:
|
| 28 |
"""
|
| 29 |
Validate a single biomarker value against reference ranges.
|
| 30 |
-
|
| 31 |
Args:
|
| 32 |
name: Biomarker name
|
| 33 |
value: Measured value
|
| 34 |
gender: "male" or "female" (for gender-specific ranges)
|
| 35 |
threshold_pct: Only flag LOW/HIGH if deviation from boundary exceeds this fraction (e.g. 0.15 = 15%)
|
| 36 |
-
|
| 37 |
Returns:
|
| 38 |
BiomarkerFlag object with status and warnings
|
| 39 |
"""
|
|
@@ -44,27 +40,27 @@ class BiomarkerValidator:
|
|
| 44 |
unit="unknown",
|
| 45 |
status="UNKNOWN",
|
| 46 |
reference_range="No reference data available",
|
| 47 |
-
warning=f"No reference range found for {name}"
|
| 48 |
)
|
| 49 |
|
| 50 |
ref = self.references[name]
|
| 51 |
-
unit = ref[
|
| 52 |
|
| 53 |
# Handle gender-specific ranges
|
| 54 |
-
if ref.get(
|
| 55 |
-
if gender.lower() in [
|
| 56 |
-
normal = ref[
|
| 57 |
-
elif gender.lower() in [
|
| 58 |
-
normal = ref[
|
| 59 |
else:
|
| 60 |
-
normal = ref[
|
| 61 |
else:
|
| 62 |
-
normal = ref[
|
| 63 |
|
| 64 |
-
min_val = normal.get(
|
| 65 |
-
max_val = normal.get(
|
| 66 |
-
critical_low = ref.get(
|
| 67 |
-
critical_high = ref.get(
|
| 68 |
|
| 69 |
# Determine status
|
| 70 |
status = "NORMAL"
|
|
@@ -92,28 +88,20 @@ class BiomarkerValidator:
|
|
| 92 |
reference_range = f"{min_val}-{max_val} {unit}"
|
| 93 |
|
| 94 |
return BiomarkerFlag(
|
| 95 |
-
name=name,
|
| 96 |
-
value=value,
|
| 97 |
-
unit=unit,
|
| 98 |
-
status=status,
|
| 99 |
-
reference_range=reference_range,
|
| 100 |
-
warning=warning
|
| 101 |
)
|
| 102 |
|
| 103 |
def validate_all(
|
| 104 |
-
self,
|
| 105 |
-
biomarkers: dict[str, float],
|
| 106 |
-
gender: str | None = None,
|
| 107 |
-
threshold_pct: float = 0.0
|
| 108 |
) -> tuple[list[BiomarkerFlag], list[SafetyAlert]]:
|
| 109 |
"""
|
| 110 |
Validate all biomarker values.
|
| 111 |
-
|
| 112 |
Args:
|
| 113 |
biomarkers: Dict of biomarker name -> value
|
| 114 |
gender: "male" or "female" (for gender-specific ranges)
|
| 115 |
threshold_pct: Only flag LOW/HIGH if deviation exceeds this fraction (e.g. 0.15 = 15%)
|
| 116 |
-
|
| 117 |
Returns:
|
| 118 |
Tuple of (biomarker_flags, safety_alerts)
|
| 119 |
"""
|
|
@@ -126,20 +114,24 @@ class BiomarkerValidator:
|
|
| 126 |
|
| 127 |
# Generate safety alerts for critical values
|
| 128 |
if flag.status in ["CRITICAL_LOW", "CRITICAL_HIGH"]:
|
| 129 |
-
alerts.append(
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
| 135 |
elif flag.status in ["LOW", "HIGH"]:
|
| 136 |
severity = "HIGH" if "severe" in (flag.warning or "").lower() else "MEDIUM"
|
| 137 |
-
alerts.append(
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
|
| 144 |
return flags, alerts
|
| 145 |
|
|
@@ -154,40 +146,57 @@ class BiomarkerValidator:
|
|
| 154 |
def get_disease_relevant_biomarkers(self, disease: str) -> list[str]:
|
| 155 |
"""
|
| 156 |
Get list of biomarkers most relevant to a specific disease.
|
| 157 |
-
|
| 158 |
This is a simplified mapping - in production, this would be more sophisticated.
|
| 159 |
"""
|
| 160 |
disease_map = {
|
| 161 |
-
"Diabetes": [
|
| 162 |
-
"Glucose", "HbA1c", "Insulin", "BMI",
|
| 163 |
-
"Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
|
| 164 |
-
],
|
| 165 |
"Type 2 Diabetes": [
|
| 166 |
-
"Glucose",
|
| 167 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
],
|
| 169 |
"Type 1 Diabetes": [
|
| 170 |
-
"Glucose",
|
| 171 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
],
|
| 173 |
"Anemia": [
|
| 174 |
-
"Hemoglobin",
|
| 175 |
-
"
|
| 176 |
-
"
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
"
|
| 180 |
],
|
|
|
|
| 181 |
"Thalassemia": [
|
| 182 |
-
"Hemoglobin",
|
| 183 |
-
"
|
|
|
|
|
|
|
|
|
|
| 184 |
],
|
| 185 |
"Heart Disease": [
|
| 186 |
-
"Cholesterol",
|
| 187 |
-
"
|
| 188 |
-
"
|
| 189 |
-
"
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
}
|
| 192 |
|
| 193 |
return disease_map.get(disease, [])
|
|
|
|
| 16 |
"""Load biomarker reference ranges from JSON file"""
|
| 17 |
ref_path = Path(__file__).parent.parent / reference_file
|
| 18 |
with open(ref_path) as f:
|
| 19 |
+
self.references = json.load(f)["biomarkers"]
|
| 20 |
|
| 21 |
def validate_biomarker(
|
| 22 |
+
self, name: str, value: float, gender: str | None = None, threshold_pct: float = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
) -> BiomarkerFlag:
|
| 24 |
"""
|
| 25 |
Validate a single biomarker value against reference ranges.
|
| 26 |
+
|
| 27 |
Args:
|
| 28 |
name: Biomarker name
|
| 29 |
value: Measured value
|
| 30 |
gender: "male" or "female" (for gender-specific ranges)
|
| 31 |
threshold_pct: Only flag LOW/HIGH if deviation from boundary exceeds this fraction (e.g. 0.15 = 15%)
|
| 32 |
+
|
| 33 |
Returns:
|
| 34 |
BiomarkerFlag object with status and warnings
|
| 35 |
"""
|
|
|
|
| 40 |
unit="unknown",
|
| 41 |
status="UNKNOWN",
|
| 42 |
reference_range="No reference data available",
|
| 43 |
+
warning=f"No reference range found for {name}",
|
| 44 |
)
|
| 45 |
|
| 46 |
ref = self.references[name]
|
| 47 |
+
unit = ref["unit"]
|
| 48 |
|
| 49 |
# Handle gender-specific ranges
|
| 50 |
+
if ref.get("gender_specific", False) and gender:
|
| 51 |
+
if gender.lower() in ["male", "m"]:
|
| 52 |
+
normal = ref["normal_range"]["male"]
|
| 53 |
+
elif gender.lower() in ["female", "f"]:
|
| 54 |
+
normal = ref["normal_range"]["female"]
|
| 55 |
else:
|
| 56 |
+
normal = ref["normal_range"]
|
| 57 |
else:
|
| 58 |
+
normal = ref["normal_range"]
|
| 59 |
|
| 60 |
+
min_val = normal.get("min", 0)
|
| 61 |
+
max_val = normal.get("max", float("inf"))
|
| 62 |
+
critical_low = ref.get("critical_low")
|
| 63 |
+
critical_high = ref.get("critical_high")
|
| 64 |
|
| 65 |
# Determine status
|
| 66 |
status = "NORMAL"
|
|
|
|
| 88 |
reference_range = f"{min_val}-{max_val} {unit}"
|
| 89 |
|
| 90 |
return BiomarkerFlag(
|
| 91 |
+
name=name, value=value, unit=unit, status=status, reference_range=reference_range, warning=warning
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
)
|
| 93 |
|
| 94 |
def validate_all(
|
| 95 |
+
self, biomarkers: dict[str, float], gender: str | None = None, threshold_pct: float = 0.0
|
|
|
|
|
|
|
|
|
|
| 96 |
) -> tuple[list[BiomarkerFlag], list[SafetyAlert]]:
|
| 97 |
"""
|
| 98 |
Validate all biomarker values.
|
| 99 |
+
|
| 100 |
Args:
|
| 101 |
biomarkers: Dict of biomarker name -> value
|
| 102 |
gender: "male" or "female" (for gender-specific ranges)
|
| 103 |
threshold_pct: Only flag LOW/HIGH if deviation exceeds this fraction (e.g. 0.15 = 15%)
|
| 104 |
+
|
| 105 |
Returns:
|
| 106 |
Tuple of (biomarker_flags, safety_alerts)
|
| 107 |
"""
|
|
|
|
| 114 |
|
| 115 |
# Generate safety alerts for critical values
|
| 116 |
if flag.status in ["CRITICAL_LOW", "CRITICAL_HIGH"]:
|
| 117 |
+
alerts.append(
|
| 118 |
+
SafetyAlert(
|
| 119 |
+
severity="CRITICAL",
|
| 120 |
+
biomarker=name,
|
| 121 |
+
message=flag.warning or f"{name} at critical level",
|
| 122 |
+
action="SEEK IMMEDIATE MEDICAL ATTENTION",
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
elif flag.status in ["LOW", "HIGH"]:
|
| 126 |
severity = "HIGH" if "severe" in (flag.warning or "").lower() else "MEDIUM"
|
| 127 |
+
alerts.append(
|
| 128 |
+
SafetyAlert(
|
| 129 |
+
severity=severity,
|
| 130 |
+
biomarker=name,
|
| 131 |
+
message=flag.warning or f"{name} out of normal range",
|
| 132 |
+
action="Consult with healthcare provider",
|
| 133 |
+
)
|
| 134 |
+
)
|
| 135 |
|
| 136 |
return flags, alerts
|
| 137 |
|
|
|
|
| 146 |
def get_disease_relevant_biomarkers(self, disease: str) -> list[str]:
|
| 147 |
"""
|
| 148 |
Get list of biomarkers most relevant to a specific disease.
|
| 149 |
+
|
| 150 |
This is a simplified mapping - in production, this would be more sophisticated.
|
| 151 |
"""
|
| 152 |
disease_map = {
|
| 153 |
+
"Diabetes": ["Glucose", "HbA1c", "Insulin", "BMI", "Triglycerides", "HDL Cholesterol", "LDL Cholesterol"],
|
|
|
|
|
|
|
|
|
|
| 154 |
"Type 2 Diabetes": [
|
| 155 |
+
"Glucose",
|
| 156 |
+
"HbA1c",
|
| 157 |
+
"Insulin",
|
| 158 |
+
"BMI",
|
| 159 |
+
"Triglycerides",
|
| 160 |
+
"HDL Cholesterol",
|
| 161 |
+
"LDL Cholesterol",
|
| 162 |
],
|
| 163 |
"Type 1 Diabetes": [
|
| 164 |
+
"Glucose",
|
| 165 |
+
"HbA1c",
|
| 166 |
+
"Insulin",
|
| 167 |
+
"BMI",
|
| 168 |
+
"Triglycerides",
|
| 169 |
+
"HDL Cholesterol",
|
| 170 |
+
"LDL Cholesterol",
|
| 171 |
],
|
| 172 |
"Anemia": [
|
| 173 |
+
"Hemoglobin",
|
| 174 |
+
"Red Blood Cells",
|
| 175 |
+
"Hematocrit",
|
| 176 |
+
"Mean Corpuscular Volume",
|
| 177 |
+
"Mean Corpuscular Hemoglobin",
|
| 178 |
+
"Mean Corpuscular Hemoglobin Concentration",
|
| 179 |
],
|
| 180 |
+
"Thrombocytopenia": ["Platelets", "White Blood Cells", "Hemoglobin"],
|
| 181 |
"Thalassemia": [
|
| 182 |
+
"Hemoglobin",
|
| 183 |
+
"Red Blood Cells",
|
| 184 |
+
"Mean Corpuscular Volume",
|
| 185 |
+
"Mean Corpuscular Hemoglobin",
|
| 186 |
+
"Hematocrit",
|
| 187 |
],
|
| 188 |
"Heart Disease": [
|
| 189 |
+
"Cholesterol",
|
| 190 |
+
"LDL Cholesterol",
|
| 191 |
+
"HDL Cholesterol",
|
| 192 |
+
"Triglycerides",
|
| 193 |
+
"Troponin",
|
| 194 |
+
"C-reactive Protein",
|
| 195 |
+
"Systolic Blood Pressure",
|
| 196 |
+
"Diastolic Blood Pressure",
|
| 197 |
+
"Heart Rate",
|
| 198 |
+
"BMI",
|
| 199 |
+
],
|
| 200 |
}
|
| 201 |
|
| 202 |
return disease_map.get(disease, [])
|
src/config.py
CHANGED
|
@@ -17,24 +17,16 @@ class ExplanationSOP(BaseModel):
|
|
| 17 |
|
| 18 |
# === Agent Behavior Parameters ===
|
| 19 |
biomarker_analyzer_threshold: float = Field(
|
| 20 |
-
default=0.15,
|
| 21 |
-
description="Percentage deviation from normal range to trigger a warning flag (0.15 = 15%)"
|
| 22 |
)
|
| 23 |
|
| 24 |
disease_explainer_k: int = Field(
|
| 25 |
-
default=5,
|
| 26 |
-
description="Number of top PDF chunks to retrieve for disease explanation"
|
| 27 |
)
|
| 28 |
|
| 29 |
-
linker_retrieval_k: int = Field(
|
| 30 |
-
default=3,
|
| 31 |
-
description="Number of chunks for biomarker-disease linking"
|
| 32 |
-
)
|
| 33 |
|
| 34 |
-
guideline_retrieval_k: int = Field(
|
| 35 |
-
default=3,
|
| 36 |
-
description="Number of chunks for clinical guidelines"
|
| 37 |
-
)
|
| 38 |
|
| 39 |
# === Prompts (Evolvable) ===
|
| 40 |
planner_prompt: str = Field(
|
|
@@ -48,7 +40,7 @@ Available specialist agents:
|
|
| 48 |
- Confidence Assessor: Evaluates prediction reliability
|
| 49 |
|
| 50 |
Output a JSON with key 'plan' containing a list of tasks. Each task must have 'agent', 'task_description', and 'dependencies' keys.""",
|
| 51 |
-
description="System prompt for the Planner Agent"
|
| 52 |
)
|
| 53 |
|
| 54 |
synthesizer_prompt: str = Field(
|
|
@@ -63,45 +55,36 @@ Output a JSON with key 'plan' containing a list of tasks. Each task must have 'a
|
|
| 63 |
- Be transparent about limitations and uncertainties
|
| 64 |
|
| 65 |
Structure your output as specified in the output schema.""",
|
| 66 |
-
description="System prompt for the Response Synthesizer"
|
| 67 |
)
|
| 68 |
|
| 69 |
explainer_detail_level: Literal["concise", "detailed", "comprehensive"] = Field(
|
| 70 |
-
default="detailed",
|
| 71 |
-
description="Level of detail in disease mechanism explanations"
|
| 72 |
)
|
| 73 |
|
| 74 |
# === Feature Flags ===
|
| 75 |
use_guideline_agent: bool = Field(
|
| 76 |
-
default=True,
|
| 77 |
-
description="Whether to retrieve clinical guidelines and recommendations"
|
| 78 |
)
|
| 79 |
|
| 80 |
include_alternative_diagnoses: bool = Field(
|
| 81 |
-
default=True,
|
| 82 |
-
description="Whether to discuss alternative diagnoses from prediction probabilities"
|
| 83 |
)
|
| 84 |
|
| 85 |
-
require_pdf_citations: bool = Field(
|
| 86 |
-
default=True,
|
| 87 |
-
description="Whether to require PDF citations for all claims"
|
| 88 |
-
)
|
| 89 |
|
| 90 |
use_confidence_assessor: bool = Field(
|
| 91 |
-
default=True,
|
| 92 |
-
description="Whether to evaluate and report prediction confidence"
|
| 93 |
)
|
| 94 |
|
| 95 |
# === Safety Settings ===
|
| 96 |
critical_value_alert_mode: Literal["strict", "moderate", "permissive"] = Field(
|
| 97 |
-
default="strict",
|
| 98 |
-
description="Threshold for critical value alerts"
|
| 99 |
)
|
| 100 |
|
| 101 |
# === Model Selection ===
|
| 102 |
synthesizer_model: str = Field(
|
| 103 |
-
default="default",
|
| 104 |
-
description="LLM to use for final response synthesis (uses provider default)"
|
| 105 |
)
|
| 106 |
|
| 107 |
|
|
@@ -117,5 +100,5 @@ BASELINE_SOP = ExplanationSOP(
|
|
| 117 |
require_pdf_citations=True,
|
| 118 |
use_confidence_assessor=True,
|
| 119 |
critical_value_alert_mode="strict",
|
| 120 |
-
synthesizer_model="default"
|
| 121 |
)
|
|
|
|
| 17 |
|
| 18 |
# === Agent Behavior Parameters ===
|
| 19 |
biomarker_analyzer_threshold: float = Field(
|
| 20 |
+
default=0.15, description="Percentage deviation from normal range to trigger a warning flag (0.15 = 15%)"
|
|
|
|
| 21 |
)
|
| 22 |
|
| 23 |
disease_explainer_k: int = Field(
|
| 24 |
+
default=5, description="Number of top PDF chunks to retrieve for disease explanation"
|
|
|
|
| 25 |
)
|
| 26 |
|
| 27 |
+
linker_retrieval_k: int = Field(default=3, description="Number of chunks for biomarker-disease linking")
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
guideline_retrieval_k: int = Field(default=3, description="Number of chunks for clinical guidelines")
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# === Prompts (Evolvable) ===
|
| 32 |
planner_prompt: str = Field(
|
|
|
|
| 40 |
- Confidence Assessor: Evaluates prediction reliability
|
| 41 |
|
| 42 |
Output a JSON with key 'plan' containing a list of tasks. Each task must have 'agent', 'task_description', and 'dependencies' keys.""",
|
| 43 |
+
description="System prompt for the Planner Agent",
|
| 44 |
)
|
| 45 |
|
| 46 |
synthesizer_prompt: str = Field(
|
|
|
|
| 55 |
- Be transparent about limitations and uncertainties
|
| 56 |
|
| 57 |
Structure your output as specified in the output schema.""",
|
| 58 |
+
description="System prompt for the Response Synthesizer",
|
| 59 |
)
|
| 60 |
|
| 61 |
explainer_detail_level: Literal["concise", "detailed", "comprehensive"] = Field(
|
| 62 |
+
default="detailed", description="Level of detail in disease mechanism explanations"
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
# === Feature Flags ===
|
| 66 |
use_guideline_agent: bool = Field(
|
| 67 |
+
default=True, description="Whether to retrieve clinical guidelines and recommendations"
|
|
|
|
| 68 |
)
|
| 69 |
|
| 70 |
include_alternative_diagnoses: bool = Field(
|
| 71 |
+
default=True, description="Whether to discuss alternative diagnoses from prediction probabilities"
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
+
require_pdf_citations: bool = Field(default=True, description="Whether to require PDF citations for all claims")
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
use_confidence_assessor: bool = Field(
|
| 77 |
+
default=True, description="Whether to evaluate and report prediction confidence"
|
|
|
|
| 78 |
)
|
| 79 |
|
| 80 |
# === Safety Settings ===
|
| 81 |
critical_value_alert_mode: Literal["strict", "moderate", "permissive"] = Field(
|
| 82 |
+
default="strict", description="Threshold for critical value alerts"
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
# === Model Selection ===
|
| 86 |
synthesizer_model: str = Field(
|
| 87 |
+
default="default", description="LLM to use for final response synthesis (uses provider default)"
|
|
|
|
| 88 |
)
|
| 89 |
|
| 90 |
|
|
|
|
| 100 |
require_pdf_citations=True,
|
| 101 |
use_confidence_assessor=True,
|
| 102 |
critical_value_alert_mode="strict",
|
| 103 |
+
synthesizer_model="default",
|
| 104 |
)
|
src/database.py
CHANGED
|
@@ -17,6 +17,7 @@ from src.settings import get_settings
|
|
| 17 |
|
| 18 |
class Base(DeclarativeBase):
|
| 19 |
"""Shared declarative base for all ORM models."""
|
|
|
|
| 20 |
pass
|
| 21 |
|
| 22 |
|
|
|
|
| 17 |
|
| 18 |
class Base(DeclarativeBase):
|
| 19 |
"""Shared declarative base for all ORM models."""
|
| 20 |
+
|
| 21 |
pass
|
| 22 |
|
| 23 |
|
src/evaluation/__init__.py
CHANGED
|
@@ -15,12 +15,12 @@ from .evaluators import (
|
|
| 15 |
)
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
]
|
|
|
|
| 15 |
)
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
+
"EvaluationResult",
|
| 19 |
+
"GradedScore",
|
| 20 |
+
"evaluate_actionability",
|
| 21 |
+
"evaluate_clarity",
|
| 22 |
+
"evaluate_clinical_accuracy",
|
| 23 |
+
"evaluate_evidence_grounding",
|
| 24 |
+
"evaluate_safety_completeness",
|
| 25 |
+
"run_full_evaluation",
|
| 26 |
]
|
src/evaluation/evaluators.py
CHANGED
|
@@ -17,7 +17,7 @@ IMPORTANT LIMITATIONS:
|
|
| 17 |
|
| 18 |
Usage:
|
| 19 |
from src.evaluation.evaluators import run_5d_evaluation
|
| 20 |
-
|
| 21 |
result = run_5d_evaluation(final_response, pubmed_context)
|
| 22 |
print(f"Average score: {result.average_score():.2f}")
|
| 23 |
"""
|
|
@@ -37,12 +37,14 @@ DETERMINISTIC_MODE = os.environ.get("EVALUATION_DETERMINISTIC", "false").lower()
|
|
| 37 |
|
| 38 |
class GradedScore(BaseModel):
|
| 39 |
"""Structured score with justification"""
|
|
|
|
| 40 |
score: float = Field(description="Score from 0.0 to 1.0", ge=0.0, le=1.0)
|
| 41 |
reasoning: str = Field(description="Justification for the score")
|
| 42 |
|
| 43 |
|
| 44 |
class EvaluationResult(BaseModel):
|
| 45 |
"""Complete 5D evaluation result"""
|
|
|
|
| 46 |
clinical_accuracy: GradedScore
|
| 47 |
evidence_grounding: GradedScore
|
| 48 |
actionability: GradedScore
|
|
@@ -56,7 +58,7 @@ class EvaluationResult(BaseModel):
|
|
| 56 |
self.evidence_grounding.score,
|
| 57 |
self.actionability.score,
|
| 58 |
self.clarity.score,
|
| 59 |
-
self.safety_completeness.score
|
| 60 |
]
|
| 61 |
|
| 62 |
def average_score(self) -> float:
|
|
@@ -66,14 +68,11 @@ class EvaluationResult(BaseModel):
|
|
| 66 |
|
| 67 |
|
| 68 |
# Evaluator 1: Clinical Accuracy (LLM-as-Judge)
|
| 69 |
-
def evaluate_clinical_accuracy(
|
| 70 |
-
final_response: dict[str, Any],
|
| 71 |
-
pubmed_context: str
|
| 72 |
-
) -> GradedScore:
|
| 73 |
"""
|
| 74 |
Evaluates if medical interpretations are accurate.
|
| 75 |
Uses cloud LLM (Groq/Gemini) as expert judge.
|
| 76 |
-
|
| 77 |
In DETERMINISTIC_MODE, uses heuristics instead.
|
| 78 |
"""
|
| 79 |
# Deterministic mode for testing
|
|
@@ -81,13 +80,13 @@ def evaluate_clinical_accuracy(
|
|
| 81 |
return _deterministic_clinical_accuracy(final_response, pubmed_context)
|
| 82 |
|
| 83 |
# Use cloud LLM for evaluation (FREE via Groq/Gemini)
|
| 84 |
-
evaluator_llm = get_chat_model(
|
| 85 |
-
temperature=0.0,
|
| 86 |
-
json_mode=True
|
| 87 |
-
)
|
| 88 |
|
| 89 |
-
prompt = ChatPromptTemplate.from_messages(
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
Evaluate the following clinical assessment:
|
| 93 |
- Are biomarker interpretations medically correct?
|
|
@@ -99,8 +98,11 @@ Score 0.0 = Contains dangerous misinformation
|
|
| 99 |
|
| 100 |
Respond ONLY with valid JSON in this format:
|
| 101 |
{{"score": 0.85, "reasoning": "Your detailed justification here"}}
|
| 102 |
-
"""
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
**Patient Summary:**
|
| 106 |
{patient_summary}
|
|
@@ -113,42 +115,44 @@ Respond ONLY with valid JSON in this format:
|
|
| 113 |
|
| 114 |
**Scientific Context (Ground Truth):**
|
| 115 |
{context}
|
| 116 |
-
"""
|
| 117 |
-
|
|
|
|
|
|
|
| 118 |
|
| 119 |
chain = prompt | evaluator_llm
|
| 120 |
-
result = chain.invoke(
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
|
| 127 |
# Parse JSON response
|
| 128 |
try:
|
| 129 |
content = result.content if isinstance(result.content, str) else str(result.content)
|
| 130 |
parsed = json.loads(content)
|
| 131 |
-
return GradedScore(score=parsed[
|
| 132 |
except (json.JSONDecodeError, KeyError, TypeError):
|
| 133 |
# Fallback if JSON parsing fails β use a conservative score to avoid inflating metrics
|
| 134 |
return GradedScore(score=0.5, reasoning="Unable to parse LLM evaluation response; defaulting to neutral score.")
|
| 135 |
|
| 136 |
|
| 137 |
# Evaluator 2: Evidence Grounding (Programmatic + LLM)
|
| 138 |
-
def evaluate_evidence_grounding(
|
| 139 |
-
final_response: dict[str, Any]
|
| 140 |
-
) -> GradedScore:
|
| 141 |
"""
|
| 142 |
Checks if all claims are backed by citations.
|
| 143 |
Programmatic + LLM verification.
|
| 144 |
"""
|
| 145 |
# Count citations
|
| 146 |
-
pdf_refs = final_response[
|
| 147 |
citation_count = len(pdf_refs)
|
| 148 |
|
| 149 |
# Check key drivers have evidence
|
| 150 |
-
key_drivers = final_response[
|
| 151 |
-
drivers_with_evidence = sum(1 for d in key_drivers if d.get(
|
| 152 |
|
| 153 |
# Citation coverage score
|
| 154 |
if len(key_drivers) > 0:
|
|
@@ -169,13 +173,11 @@ def evaluate_evidence_grounding(
|
|
| 169 |
|
| 170 |
|
| 171 |
# Evaluator 3: Clinical Actionability (LLM-as-Judge)
|
| 172 |
-
def evaluate_actionability(
|
| 173 |
-
final_response: dict[str, Any]
|
| 174 |
-
) -> GradedScore:
|
| 175 |
"""
|
| 176 |
Evaluates if recommendations are actionable and safe.
|
| 177 |
Uses cloud LLM (Groq/Gemini) as expert judge.
|
| 178 |
-
|
| 179 |
In DETERMINISTIC_MODE, uses heuristics instead.
|
| 180 |
"""
|
| 181 |
# Deterministic mode for testing
|
|
@@ -183,13 +185,13 @@ def evaluate_actionability(
|
|
| 183 |
return _deterministic_actionability(final_response)
|
| 184 |
|
| 185 |
# Use cloud LLM for evaluation (FREE via Groq/Gemini)
|
| 186 |
-
evaluator_llm = get_chat_model(
|
| 187 |
-
temperature=0.0,
|
| 188 |
-
json_mode=True
|
| 189 |
-
)
|
| 190 |
|
| 191 |
-
prompt = ChatPromptTemplate.from_messages(
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
Evaluate the following recommendations:
|
| 195 |
- Are immediate actions clear and appropriate?
|
|
@@ -202,8 +204,11 @@ Score 0.0 = Vague, impractical, or unsafe
|
|
| 202 |
|
| 203 |
Respond ONLY with valid JSON in this format:
|
| 204 |
{{"score": 0.90, "reasoning": "Your detailed justification here"}}
|
| 205 |
-
"""
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
**Immediate Actions:**
|
| 209 |
{immediate_actions}
|
|
@@ -216,35 +221,37 @@ Respond ONLY with valid JSON in this format:
|
|
| 216 |
|
| 217 |
**Confidence Assessment:**
|
| 218 |
{confidence}
|
| 219 |
-
"""
|
| 220 |
-
|
|
|
|
|
|
|
| 221 |
|
| 222 |
chain = prompt | evaluator_llm
|
| 223 |
-
recs = final_response[
|
| 224 |
-
result = chain.invoke(
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
| 230 |
|
| 231 |
# Parse JSON response
|
| 232 |
try:
|
| 233 |
parsed = json.loads(result.content if isinstance(result.content, str) else str(result.content))
|
| 234 |
-
return GradedScore(score=parsed[
|
| 235 |
except (json.JSONDecodeError, KeyError, TypeError):
|
| 236 |
# Fallback if JSON parsing fails β use a conservative score to avoid inflating metrics
|
| 237 |
return GradedScore(score=0.5, reasoning="Unable to parse LLM evaluation response; defaulting to neutral score.")
|
| 238 |
|
| 239 |
|
| 240 |
# Evaluator 4: Explainability Clarity (Programmatic)
|
| 241 |
-
def evaluate_clarity(
|
| 242 |
-
final_response: dict[str, Any]
|
| 243 |
-
) -> GradedScore:
|
| 244 |
"""
|
| 245 |
Measures readability and patient-friendliness.
|
| 246 |
Uses programmatic text analysis.
|
| 247 |
-
|
| 248 |
In DETERMINISTIC_MODE, uses simple heuristics for reproducibility.
|
| 249 |
"""
|
| 250 |
# Deterministic mode for testing
|
|
@@ -253,12 +260,13 @@ def evaluate_clarity(
|
|
| 253 |
|
| 254 |
try:
|
| 255 |
import textstat
|
|
|
|
| 256 |
has_textstat = True
|
| 257 |
except ImportError:
|
| 258 |
has_textstat = False
|
| 259 |
|
| 260 |
# Get patient narrative
|
| 261 |
-
narrative = final_response[
|
| 262 |
|
| 263 |
if has_textstat:
|
| 264 |
# Calculate readability (Flesch Reading Ease)
|
|
@@ -268,7 +276,7 @@ def evaluate_clarity(
|
|
| 268 |
readability_score = min(1.0, flesch_score / 70.0) # Normalize to 1.0 at Flesch=70
|
| 269 |
else:
|
| 270 |
# Fallback: simple sentence length heuristic
|
| 271 |
-
sentences = narrative.split(
|
| 272 |
avg_words = sum(len(s.split()) for s in sentences) / max(len(sentences), 1)
|
| 273 |
# Optimal: 15-20 words per sentence
|
| 274 |
if 15 <= avg_words <= 20:
|
|
@@ -280,8 +288,13 @@ def evaluate_clarity(
|
|
| 280 |
|
| 281 |
# Medical jargon detection (simple heuristic)
|
| 282 |
medical_terms = [
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
]
|
| 286 |
jargon_count = sum(1 for term in medical_terms if term.lower() in narrative.lower())
|
| 287 |
|
|
@@ -293,7 +306,7 @@ def evaluate_clarity(
|
|
| 293 |
jargon_penalty = max(0.0, 1.0 - (jargon_count * 0.2))
|
| 294 |
length_score = 1.0 if optimal_length else 0.7
|
| 295 |
|
| 296 |
-
final_score =
|
| 297 |
|
| 298 |
if has_textstat:
|
| 299 |
reasoning = f"""
|
|
@@ -314,10 +327,7 @@ def evaluate_clarity(
|
|
| 314 |
|
| 315 |
|
| 316 |
# Evaluator 5: Safety & Completeness (Programmatic)
|
| 317 |
-
def evaluate_safety_completeness(
|
| 318 |
-
final_response: dict[str, Any],
|
| 319 |
-
biomarkers: dict[str, float]
|
| 320 |
-
) -> GradedScore:
|
| 321 |
"""
|
| 322 |
Checks if all safety concerns are flagged.
|
| 323 |
Programmatic validation.
|
|
@@ -333,24 +343,24 @@ def evaluate_safety_completeness(
|
|
| 333 |
|
| 334 |
for name, value in biomarkers.items():
|
| 335 |
result = validator.validate_biomarker(name, value) # Fixed: use validate_biomarker instead of validate_single
|
| 336 |
-
if result.status in [
|
| 337 |
out_of_range_count += 1
|
| 338 |
-
if result.status in [
|
| 339 |
critical_count += 1
|
| 340 |
|
| 341 |
# Count safety alerts in output
|
| 342 |
-
safety_alerts = final_response.get(
|
| 343 |
alert_count = len(safety_alerts)
|
| 344 |
-
critical_alerts = sum(1 for a in safety_alerts if a.get(
|
| 345 |
|
| 346 |
# Check if all critical values have alerts
|
| 347 |
critical_coverage = critical_alerts / critical_count if critical_count > 0 else 1.0
|
| 348 |
|
| 349 |
# Check for disclaimer
|
| 350 |
-
has_disclaimer =
|
| 351 |
|
| 352 |
# Check for uncertainty acknowledgment
|
| 353 |
-
limitations = final_response[
|
| 354 |
acknowledges_uncertainty = len(limitations) > 0
|
| 355 |
|
| 356 |
# Scoring
|
|
@@ -359,12 +369,9 @@ def evaluate_safety_completeness(
|
|
| 359 |
disclaimer_score = 1.0 if has_disclaimer else 0.0
|
| 360 |
uncertainty_score = 1.0 if acknowledges_uncertainty else 0.5
|
| 361 |
|
| 362 |
-
final_score = min(
|
| 363 |
-
alert_score * 0.4 +
|
| 364 |
-
|
| 365 |
-
disclaimer_score * 0.2 +
|
| 366 |
-
uncertainty_score * 0.1
|
| 367 |
-
))
|
| 368 |
|
| 369 |
reasoning = f"""
|
| 370 |
Out-of-range biomarkers: {out_of_range_count}
|
|
@@ -381,9 +388,7 @@ def evaluate_safety_completeness(
|
|
| 381 |
|
| 382 |
# Master Evaluation Function
|
| 383 |
def run_full_evaluation(
|
| 384 |
-
final_response: dict[str, Any],
|
| 385 |
-
agent_outputs: list[Any],
|
| 386 |
-
biomarkers: dict[str, float]
|
| 387 |
) -> EvaluationResult:
|
| 388 |
"""
|
| 389 |
Orchestrates all 5 evaluators and returns complete assessment.
|
|
@@ -398,7 +403,7 @@ def run_full_evaluation(
|
|
| 398 |
if output.agent_name == "Disease Explainer":
|
| 399 |
findings = output.findings
|
| 400 |
if isinstance(findings, dict):
|
| 401 |
-
pubmed_context = findings.get(
|
| 402 |
elif isinstance(findings, str):
|
| 403 |
pubmed_context = findings
|
| 404 |
else:
|
|
@@ -430,7 +435,7 @@ def run_full_evaluation(
|
|
| 430 |
evidence_grounding=evidence_grounding,
|
| 431 |
actionability=actionability,
|
| 432 |
clarity=clarity,
|
| 433 |
-
safety_completeness=safety_completeness
|
| 434 |
)
|
| 435 |
|
| 436 |
|
|
@@ -438,74 +443,65 @@ def run_full_evaluation(
|
|
| 438 |
# Deterministic Evaluation Functions (for testing)
|
| 439 |
# ---------------------------------------------------------------------------
|
| 440 |
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
pubmed_context: str
|
| 444 |
-
) -> GradedScore:
|
| 445 |
"""Heuristic-based clinical accuracy (deterministic)."""
|
| 446 |
score = 0.5
|
| 447 |
reasons = []
|
| 448 |
|
| 449 |
# Check if response has expected structure
|
| 450 |
-
if final_response.get(
|
| 451 |
score += 0.1
|
| 452 |
reasons.append("Has patient summary")
|
| 453 |
|
| 454 |
-
if final_response.get(
|
| 455 |
score += 0.1
|
| 456 |
reasons.append("Has prediction explanation")
|
| 457 |
|
| 458 |
-
if final_response.get(
|
| 459 |
score += 0.1
|
| 460 |
reasons.append("Has clinical recommendations")
|
| 461 |
|
| 462 |
# Check for citations
|
| 463 |
-
pred = final_response.get(
|
| 464 |
if isinstance(pred, dict):
|
| 465 |
-
refs = pred.get(
|
| 466 |
if refs:
|
| 467 |
score += min(0.2, len(refs) * 0.05)
|
| 468 |
reasons.append(f"Has {len(refs)} citations")
|
| 469 |
|
| 470 |
-
return GradedScore(
|
| 471 |
-
score=min(1.0, score),
|
| 472 |
-
reasoning="[DETERMINISTIC] " + "; ".join(reasons)
|
| 473 |
-
)
|
| 474 |
|
| 475 |
|
| 476 |
-
def _deterministic_actionability(
|
| 477 |
-
final_response: dict[str, Any]
|
| 478 |
-
) -> GradedScore:
|
| 479 |
"""Heuristic-based actionability (deterministic)."""
|
| 480 |
score = 0.5
|
| 481 |
reasons = []
|
| 482 |
|
| 483 |
-
recs = final_response.get(
|
| 484 |
if isinstance(recs, dict):
|
| 485 |
-
if recs.get(
|
| 486 |
score += 0.15
|
| 487 |
reasons.append("Has immediate actions")
|
| 488 |
-
if recs.get(
|
| 489 |
score += 0.15
|
| 490 |
reasons.append("Has lifestyle changes")
|
| 491 |
-
if recs.get(
|
| 492 |
score += 0.1
|
| 493 |
reasons.append("Has monitoring recommendations")
|
| 494 |
|
| 495 |
return GradedScore(
|
| 496 |
score=min(1.0, score),
|
| 497 |
-
reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Missing recommendations"
|
| 498 |
)
|
| 499 |
|
| 500 |
|
| 501 |
-
def _deterministic_clarity(
|
| 502 |
-
final_response: dict[str, Any]
|
| 503 |
-
) -> GradedScore:
|
| 504 |
"""Heuristic-based clarity (deterministic)."""
|
| 505 |
score = 0.5
|
| 506 |
reasons = []
|
| 507 |
|
| 508 |
-
summary = final_response.get(
|
| 509 |
if isinstance(summary, str):
|
| 510 |
word_count = len(summary.split())
|
| 511 |
if 50 <= word_count <= 300:
|
|
@@ -516,15 +512,15 @@ def _deterministic_clarity(
|
|
| 516 |
reasons.append("Has summary")
|
| 517 |
|
| 518 |
# Check for structured output
|
| 519 |
-
if final_response.get(
|
| 520 |
score += 0.15
|
| 521 |
reasons.append("Has biomarker flags")
|
| 522 |
|
| 523 |
-
if final_response.get(
|
| 524 |
score += 0.15
|
| 525 |
reasons.append("Has key findings")
|
| 526 |
|
| 527 |
return GradedScore(
|
| 528 |
score=min(1.0, score),
|
| 529 |
-
reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Limited structure"
|
| 530 |
)
|
|
|
|
| 17 |
|
| 18 |
Usage:
|
| 19 |
from src.evaluation.evaluators import run_5d_evaluation
|
| 20 |
+
|
| 21 |
result = run_5d_evaluation(final_response, pubmed_context)
|
| 22 |
print(f"Average score: {result.average_score():.2f}")
|
| 23 |
"""
|
|
|
|
| 37 |
|
| 38 |
class GradedScore(BaseModel):
|
| 39 |
"""Structured score with justification"""
|
| 40 |
+
|
| 41 |
score: float = Field(description="Score from 0.0 to 1.0", ge=0.0, le=1.0)
|
| 42 |
reasoning: str = Field(description="Justification for the score")
|
| 43 |
|
| 44 |
|
| 45 |
class EvaluationResult(BaseModel):
|
| 46 |
"""Complete 5D evaluation result"""
|
| 47 |
+
|
| 48 |
clinical_accuracy: GradedScore
|
| 49 |
evidence_grounding: GradedScore
|
| 50 |
actionability: GradedScore
|
|
|
|
| 58 |
self.evidence_grounding.score,
|
| 59 |
self.actionability.score,
|
| 60 |
self.clarity.score,
|
| 61 |
+
self.safety_completeness.score,
|
| 62 |
]
|
| 63 |
|
| 64 |
def average_score(self) -> float:
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
# Evaluator 1: Clinical Accuracy (LLM-as-Judge)
|
| 71 |
+
def evaluate_clinical_accuracy(final_response: dict[str, Any], pubmed_context: str) -> GradedScore:
|
|
|
|
|
|
|
|
|
|
| 72 |
"""
|
| 73 |
Evaluates if medical interpretations are accurate.
|
| 74 |
Uses cloud LLM (Groq/Gemini) as expert judge.
|
| 75 |
+
|
| 76 |
In DETERMINISTIC_MODE, uses heuristics instead.
|
| 77 |
"""
|
| 78 |
# Deterministic mode for testing
|
|
|
|
| 80 |
return _deterministic_clinical_accuracy(final_response, pubmed_context)
|
| 81 |
|
| 82 |
# Use cloud LLM for evaluation (FREE via Groq/Gemini)
|
| 83 |
+
evaluator_llm = get_chat_model(temperature=0.0, json_mode=True)
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 86 |
+
[
|
| 87 |
+
(
|
| 88 |
+
"system",
|
| 89 |
+
"""You are a medical expert evaluating clinical accuracy.
|
| 90 |
|
| 91 |
Evaluate the following clinical assessment:
|
| 92 |
- Are biomarker interpretations medically correct?
|
|
|
|
| 98 |
|
| 99 |
Respond ONLY with valid JSON in this format:
|
| 100 |
{{"score": 0.85, "reasoning": "Your detailed justification here"}}
|
| 101 |
+
""",
|
| 102 |
+
),
|
| 103 |
+
(
|
| 104 |
+
"human",
|
| 105 |
+
"""Evaluate this clinical output:
|
| 106 |
|
| 107 |
**Patient Summary:**
|
| 108 |
{patient_summary}
|
|
|
|
| 115 |
|
| 116 |
**Scientific Context (Ground Truth):**
|
| 117 |
{context}
|
| 118 |
+
""",
|
| 119 |
+
),
|
| 120 |
+
]
|
| 121 |
+
)
|
| 122 |
|
| 123 |
chain = prompt | evaluator_llm
|
| 124 |
+
result = chain.invoke(
|
| 125 |
+
{
|
| 126 |
+
"patient_summary": final_response["patient_summary"],
|
| 127 |
+
"prediction_explanation": final_response["prediction_explanation"],
|
| 128 |
+
"recommendations": final_response["clinical_recommendations"],
|
| 129 |
+
"context": pubmed_context,
|
| 130 |
+
}
|
| 131 |
+
)
|
| 132 |
|
| 133 |
# Parse JSON response
|
| 134 |
try:
|
| 135 |
content = result.content if isinstance(result.content, str) else str(result.content)
|
| 136 |
parsed = json.loads(content)
|
| 137 |
+
return GradedScore(score=parsed["score"], reasoning=parsed["reasoning"])
|
| 138 |
except (json.JSONDecodeError, KeyError, TypeError):
|
| 139 |
# Fallback if JSON parsing fails β use a conservative score to avoid inflating metrics
|
| 140 |
return GradedScore(score=0.5, reasoning="Unable to parse LLM evaluation response; defaulting to neutral score.")
|
| 141 |
|
| 142 |
|
| 143 |
# Evaluator 2: Evidence Grounding (Programmatic + LLM)
|
| 144 |
+
def evaluate_evidence_grounding(final_response: dict[str, Any]) -> GradedScore:
|
|
|
|
|
|
|
| 145 |
"""
|
| 146 |
Checks if all claims are backed by citations.
|
| 147 |
Programmatic + LLM verification.
|
| 148 |
"""
|
| 149 |
# Count citations
|
| 150 |
+
pdf_refs = final_response["prediction_explanation"].get("pdf_references", [])
|
| 151 |
citation_count = len(pdf_refs)
|
| 152 |
|
| 153 |
# Check key drivers have evidence
|
| 154 |
+
key_drivers = final_response["prediction_explanation"].get("key_drivers", [])
|
| 155 |
+
drivers_with_evidence = sum(1 for d in key_drivers if d.get("evidence"))
|
| 156 |
|
| 157 |
# Citation coverage score
|
| 158 |
if len(key_drivers) > 0:
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
# Evaluator 3: Clinical Actionability (LLM-as-Judge)
|
| 176 |
+
def evaluate_actionability(final_response: dict[str, Any]) -> GradedScore:
|
|
|
|
|
|
|
| 177 |
"""
|
| 178 |
Evaluates if recommendations are actionable and safe.
|
| 179 |
Uses cloud LLM (Groq/Gemini) as expert judge.
|
| 180 |
+
|
| 181 |
In DETERMINISTIC_MODE, uses heuristics instead.
|
| 182 |
"""
|
| 183 |
# Deterministic mode for testing
|
|
|
|
| 185 |
return _deterministic_actionability(final_response)
|
| 186 |
|
| 187 |
# Use cloud LLM for evaluation (FREE via Groq/Gemini)
|
| 188 |
+
evaluator_llm = get_chat_model(temperature=0.0, json_mode=True)
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 191 |
+
[
|
| 192 |
+
(
|
| 193 |
+
"system",
|
| 194 |
+
"""You are a clinical care coordinator evaluating actionability.
|
| 195 |
|
| 196 |
Evaluate the following recommendations:
|
| 197 |
- Are immediate actions clear and appropriate?
|
|
|
|
| 204 |
|
| 205 |
Respond ONLY with valid JSON in this format:
|
| 206 |
{{"score": 0.90, "reasoning": "Your detailed justification here"}}
|
| 207 |
+
""",
|
| 208 |
+
),
|
| 209 |
+
(
|
| 210 |
+
"human",
|
| 211 |
+
"""Evaluate these recommendations:
|
| 212 |
|
| 213 |
**Immediate Actions:**
|
| 214 |
{immediate_actions}
|
|
|
|
| 221 |
|
| 222 |
**Confidence Assessment:**
|
| 223 |
{confidence}
|
| 224 |
+
""",
|
| 225 |
+
),
|
| 226 |
+
]
|
| 227 |
+
)
|
| 228 |
|
| 229 |
chain = prompt | evaluator_llm
|
| 230 |
+
recs = final_response["clinical_recommendations"]
|
| 231 |
+
result = chain.invoke(
|
| 232 |
+
{
|
| 233 |
+
"immediate_actions": recs.get("immediate_actions", []),
|
| 234 |
+
"lifestyle_changes": recs.get("lifestyle_changes", []),
|
| 235 |
+
"monitoring": recs.get("monitoring", []),
|
| 236 |
+
"confidence": final_response["confidence_assessment"],
|
| 237 |
+
}
|
| 238 |
+
)
|
| 239 |
|
| 240 |
# Parse JSON response
|
| 241 |
try:
|
| 242 |
parsed = json.loads(result.content if isinstance(result.content, str) else str(result.content))
|
| 243 |
+
return GradedScore(score=parsed["score"], reasoning=parsed["reasoning"])
|
| 244 |
except (json.JSONDecodeError, KeyError, TypeError):
|
| 245 |
# Fallback if JSON parsing fails β use a conservative score to avoid inflating metrics
|
| 246 |
return GradedScore(score=0.5, reasoning="Unable to parse LLM evaluation response; defaulting to neutral score.")
|
| 247 |
|
| 248 |
|
| 249 |
# Evaluator 4: Explainability Clarity (Programmatic)
|
| 250 |
+
def evaluate_clarity(final_response: dict[str, Any]) -> GradedScore:
|
|
|
|
|
|
|
| 251 |
"""
|
| 252 |
Measures readability and patient-friendliness.
|
| 253 |
Uses programmatic text analysis.
|
| 254 |
+
|
| 255 |
In DETERMINISTIC_MODE, uses simple heuristics for reproducibility.
|
| 256 |
"""
|
| 257 |
# Deterministic mode for testing
|
|
|
|
| 260 |
|
| 261 |
try:
|
| 262 |
import textstat
|
| 263 |
+
|
| 264 |
has_textstat = True
|
| 265 |
except ImportError:
|
| 266 |
has_textstat = False
|
| 267 |
|
| 268 |
# Get patient narrative
|
| 269 |
+
narrative = final_response["patient_summary"].get("narrative", "")
|
| 270 |
|
| 271 |
if has_textstat:
|
| 272 |
# Calculate readability (Flesch Reading Ease)
|
|
|
|
| 276 |
readability_score = min(1.0, flesch_score / 70.0) # Normalize to 1.0 at Flesch=70
|
| 277 |
else:
|
| 278 |
# Fallback: simple sentence length heuristic
|
| 279 |
+
sentences = narrative.split(".")
|
| 280 |
avg_words = sum(len(s.split()) for s in sentences) / max(len(sentences), 1)
|
| 281 |
# Optimal: 15-20 words per sentence
|
| 282 |
if 15 <= avg_words <= 20:
|
|
|
|
| 288 |
|
| 289 |
# Medical jargon detection (simple heuristic)
|
| 290 |
medical_terms = [
|
| 291 |
+
"pathophysiology",
|
| 292 |
+
"etiology",
|
| 293 |
+
"hemostasis",
|
| 294 |
+
"coagulation",
|
| 295 |
+
"thrombocytopenia",
|
| 296 |
+
"erythropoiesis",
|
| 297 |
+
"gluconeogenesis",
|
| 298 |
]
|
| 299 |
jargon_count = sum(1 for term in medical_terms if term.lower() in narrative.lower())
|
| 300 |
|
|
|
|
| 306 |
jargon_penalty = max(0.0, 1.0 - (jargon_count * 0.2))
|
| 307 |
length_score = 1.0 if optimal_length else 0.7
|
| 308 |
|
| 309 |
+
final_score = readability_score * 0.5 + jargon_penalty * 0.3 + length_score * 0.2
|
| 310 |
|
| 311 |
if has_textstat:
|
| 312 |
reasoning = f"""
|
|
|
|
| 327 |
|
| 328 |
|
| 329 |
# Evaluator 5: Safety & Completeness (Programmatic)
|
| 330 |
+
def evaluate_safety_completeness(final_response: dict[str, Any], biomarkers: dict[str, float]) -> GradedScore:
|
|
|
|
|
|
|
|
|
|
| 331 |
"""
|
| 332 |
Checks if all safety concerns are flagged.
|
| 333 |
Programmatic validation.
|
|
|
|
| 343 |
|
| 344 |
for name, value in biomarkers.items():
|
| 345 |
result = validator.validate_biomarker(name, value) # Fixed: use validate_biomarker instead of validate_single
|
| 346 |
+
if result.status in ["HIGH", "LOW", "CRITICAL_HIGH", "CRITICAL_LOW"]:
|
| 347 |
out_of_range_count += 1
|
| 348 |
+
if result.status in ["CRITICAL_HIGH", "CRITICAL_LOW"]:
|
| 349 |
critical_count += 1
|
| 350 |
|
| 351 |
# Count safety alerts in output
|
| 352 |
+
safety_alerts = final_response.get("safety_alerts", [])
|
| 353 |
alert_count = len(safety_alerts)
|
| 354 |
+
critical_alerts = sum(1 for a in safety_alerts if a.get("severity") == "CRITICAL")
|
| 355 |
|
| 356 |
# Check if all critical values have alerts
|
| 357 |
critical_coverage = critical_alerts / critical_count if critical_count > 0 else 1.0
|
| 358 |
|
| 359 |
# Check for disclaimer
|
| 360 |
+
has_disclaimer = "disclaimer" in final_response.get("metadata", {})
|
| 361 |
|
| 362 |
# Check for uncertainty acknowledgment
|
| 363 |
+
limitations = final_response["confidence_assessment"].get("limitations", [])
|
| 364 |
acknowledges_uncertainty = len(limitations) > 0
|
| 365 |
|
| 366 |
# Scoring
|
|
|
|
| 369 |
disclaimer_score = 1.0 if has_disclaimer else 0.0
|
| 370 |
uncertainty_score = 1.0 if acknowledges_uncertainty else 0.5
|
| 371 |
|
| 372 |
+
final_score = min(
|
| 373 |
+
1.0, (alert_score * 0.4 + critical_score * 0.3 + disclaimer_score * 0.2 + uncertainty_score * 0.1)
|
| 374 |
+
)
|
|
|
|
|
|
|
|
|
|
| 375 |
|
| 376 |
reasoning = f"""
|
| 377 |
Out-of-range biomarkers: {out_of_range_count}
|
|
|
|
| 388 |
|
| 389 |
# Master Evaluation Function
|
| 390 |
def run_full_evaluation(
|
| 391 |
+
final_response: dict[str, Any], agent_outputs: list[Any], biomarkers: dict[str, float]
|
|
|
|
|
|
|
| 392 |
) -> EvaluationResult:
|
| 393 |
"""
|
| 394 |
Orchestrates all 5 evaluators and returns complete assessment.
|
|
|
|
| 403 |
if output.agent_name == "Disease Explainer":
|
| 404 |
findings = output.findings
|
| 405 |
if isinstance(findings, dict):
|
| 406 |
+
pubmed_context = findings.get("mechanism_summary", "") or findings.get("pathophysiology", "")
|
| 407 |
elif isinstance(findings, str):
|
| 408 |
pubmed_context = findings
|
| 409 |
else:
|
|
|
|
| 435 |
evidence_grounding=evidence_grounding,
|
| 436 |
actionability=actionability,
|
| 437 |
clarity=clarity,
|
| 438 |
+
safety_completeness=safety_completeness,
|
| 439 |
)
|
| 440 |
|
| 441 |
|
|
|
|
| 443 |
# Deterministic Evaluation Functions (for testing)
|
| 444 |
# ---------------------------------------------------------------------------
|
| 445 |
|
| 446 |
+
|
| 447 |
+
def _deterministic_clinical_accuracy(final_response: dict[str, Any], pubmed_context: str) -> GradedScore:
|
|
|
|
|
|
|
| 448 |
"""Heuristic-based clinical accuracy (deterministic)."""
|
| 449 |
score = 0.5
|
| 450 |
reasons = []
|
| 451 |
|
| 452 |
# Check if response has expected structure
|
| 453 |
+
if final_response.get("patient_summary"):
|
| 454 |
score += 0.1
|
| 455 |
reasons.append("Has patient summary")
|
| 456 |
|
| 457 |
+
if final_response.get("prediction_explanation"):
|
| 458 |
score += 0.1
|
| 459 |
reasons.append("Has prediction explanation")
|
| 460 |
|
| 461 |
+
if final_response.get("clinical_recommendations"):
|
| 462 |
score += 0.1
|
| 463 |
reasons.append("Has clinical recommendations")
|
| 464 |
|
| 465 |
# Check for citations
|
| 466 |
+
pred = final_response.get("prediction_explanation", {})
|
| 467 |
if isinstance(pred, dict):
|
| 468 |
+
refs = pred.get("pdf_references", [])
|
| 469 |
if refs:
|
| 470 |
score += min(0.2, len(refs) * 0.05)
|
| 471 |
reasons.append(f"Has {len(refs)} citations")
|
| 472 |
|
| 473 |
+
return GradedScore(score=min(1.0, score), reasoning="[DETERMINISTIC] " + "; ".join(reasons))
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
|
| 476 |
+
def _deterministic_actionability(final_response: dict[str, Any]) -> GradedScore:
|
|
|
|
|
|
|
| 477 |
"""Heuristic-based actionability (deterministic)."""
|
| 478 |
score = 0.5
|
| 479 |
reasons = []
|
| 480 |
|
| 481 |
+
recs = final_response.get("clinical_recommendations", {})
|
| 482 |
if isinstance(recs, dict):
|
| 483 |
+
if recs.get("immediate_actions"):
|
| 484 |
score += 0.15
|
| 485 |
reasons.append("Has immediate actions")
|
| 486 |
+
if recs.get("lifestyle_changes"):
|
| 487 |
score += 0.15
|
| 488 |
reasons.append("Has lifestyle changes")
|
| 489 |
+
if recs.get("monitoring"):
|
| 490 |
score += 0.1
|
| 491 |
reasons.append("Has monitoring recommendations")
|
| 492 |
|
| 493 |
return GradedScore(
|
| 494 |
score=min(1.0, score),
|
| 495 |
+
reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Missing recommendations",
|
| 496 |
)
|
| 497 |
|
| 498 |
|
| 499 |
+
def _deterministic_clarity(final_response: dict[str, Any]) -> GradedScore:
|
|
|
|
|
|
|
| 500 |
"""Heuristic-based clarity (deterministic)."""
|
| 501 |
score = 0.5
|
| 502 |
reasons = []
|
| 503 |
|
| 504 |
+
summary = final_response.get("patient_summary", "")
|
| 505 |
if isinstance(summary, str):
|
| 506 |
word_count = len(summary.split())
|
| 507 |
if 50 <= word_count <= 300:
|
|
|
|
| 512 |
reasons.append("Has summary")
|
| 513 |
|
| 514 |
# Check for structured output
|
| 515 |
+
if final_response.get("biomarker_flags"):
|
| 516 |
score += 0.15
|
| 517 |
reasons.append("Has biomarker flags")
|
| 518 |
|
| 519 |
+
if final_response.get("key_findings"):
|
| 520 |
score += 0.15
|
| 521 |
reasons.append("Has key findings")
|
| 522 |
|
| 523 |
return GradedScore(
|
| 524 |
score=min(1.0, score),
|
| 525 |
+
reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Limited structure",
|
| 526 |
)
|
src/exceptions.py
CHANGED
|
@@ -10,6 +10,7 @@ from typing import Any
|
|
| 10 |
|
| 11 |
# ββ Base ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 12 |
|
|
|
|
| 13 |
class MediGuardError(Exception):
|
| 14 |
"""Root exception for the entire MediGuard AI application."""
|
| 15 |
|
|
@@ -20,6 +21,7 @@ class MediGuardError(Exception):
|
|
| 20 |
|
| 21 |
# ββ Configuration / startup ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
|
|
|
|
| 23 |
class ConfigurationError(MediGuardError):
|
| 24 |
"""Raised when a required setting is missing or invalid."""
|
| 25 |
|
|
@@ -30,6 +32,7 @@ class ServiceInitError(MediGuardError):
|
|
| 30 |
|
| 31 |
# ββ Database βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
|
|
|
|
| 33 |
class DatabaseError(MediGuardError):
|
| 34 |
"""Base class for all database-related errors."""
|
| 35 |
|
|
@@ -44,6 +47,7 @@ class RecordNotFoundError(DatabaseError):
|
|
| 44 |
|
| 45 |
# ββ Search engine ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
|
|
|
|
| 47 |
class SearchError(MediGuardError):
|
| 48 |
"""Base class for search-engine (OpenSearch) errors."""
|
| 49 |
|
|
@@ -58,6 +62,7 @@ class SearchQueryError(SearchError):
|
|
| 58 |
|
| 59 |
# ββ Embeddings βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 60 |
|
|
|
|
| 61 |
class EmbeddingError(MediGuardError):
|
| 62 |
"""Failed to generate embeddings."""
|
| 63 |
|
|
@@ -68,6 +73,7 @@ class EmbeddingProviderError(EmbeddingError):
|
|
| 68 |
|
| 69 |
# ββ PDF / document parsing βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
|
|
|
|
| 71 |
class PDFParsingError(MediGuardError):
|
| 72 |
"""Base class for PDF-processing errors."""
|
| 73 |
|
|
@@ -82,6 +88,7 @@ class PDFValidationError(PDFParsingError):
|
|
| 82 |
|
| 83 |
# ββ LLM / Ollama βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
|
|
|
|
| 85 |
class LLMError(MediGuardError):
|
| 86 |
"""Base class for LLM-related errors."""
|
| 87 |
|
|
@@ -100,6 +107,7 @@ class LLMResponseError(LLMError):
|
|
| 100 |
|
| 101 |
# ββ Biomarker domain βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 102 |
|
|
|
|
| 103 |
class BiomarkerError(MediGuardError):
|
| 104 |
"""Base class for biomarker-related errors."""
|
| 105 |
|
|
@@ -114,6 +122,7 @@ class BiomarkerNotFoundError(BiomarkerError):
|
|
| 114 |
|
| 115 |
# ββ Medical analysis / workflow ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
|
|
|
|
| 117 |
class AnalysisError(MediGuardError):
|
| 118 |
"""The clinical-analysis workflow encountered an error."""
|
| 119 |
|
|
@@ -128,6 +137,7 @@ class OutOfScopeError(GuardrailError):
|
|
| 128 |
|
| 129 |
# ββ Cache ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 130 |
|
|
|
|
| 131 |
class CacheError(MediGuardError):
|
| 132 |
"""Base class for cache (Redis) errors."""
|
| 133 |
|
|
@@ -138,11 +148,13 @@ class CacheConnectionError(CacheError):
|
|
| 138 |
|
| 139 |
# ββ Observability ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 140 |
|
|
|
|
| 141 |
class ObservabilityError(MediGuardError):
|
| 142 |
"""Langfuse or metrics reporting failed (non-fatal)."""
|
| 143 |
|
| 144 |
|
| 145 |
# ββ Telegram bot βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
|
|
|
|
| 147 |
class TelegramError(MediGuardError):
|
| 148 |
"""Error from the Telegram bot integration."""
|
|
|
|
| 10 |
|
| 11 |
# ββ Base ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 12 |
|
| 13 |
+
|
| 14 |
class MediGuardError(Exception):
|
| 15 |
"""Root exception for the entire MediGuard AI application."""
|
| 16 |
|
|
|
|
| 21 |
|
| 22 |
# ββ Configuration / startup ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
|
| 24 |
+
|
| 25 |
class ConfigurationError(MediGuardError):
|
| 26 |
"""Raised when a required setting is missing or invalid."""
|
| 27 |
|
|
|
|
| 32 |
|
| 33 |
# ββ Database βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
|
| 35 |
+
|
| 36 |
class DatabaseError(MediGuardError):
|
| 37 |
"""Base class for all database-related errors."""
|
| 38 |
|
|
|
|
| 47 |
|
| 48 |
# ββ Search engine ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
|
| 50 |
+
|
| 51 |
class SearchError(MediGuardError):
|
| 52 |
"""Base class for search-engine (OpenSearch) errors."""
|
| 53 |
|
|
|
|
| 62 |
|
| 63 |
# ββ Embeddings βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
|
| 65 |
+
|
| 66 |
class EmbeddingError(MediGuardError):
|
| 67 |
"""Failed to generate embeddings."""
|
| 68 |
|
|
|
|
| 73 |
|
| 74 |
# ββ PDF / document parsing βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 75 |
|
| 76 |
+
|
| 77 |
class PDFParsingError(MediGuardError):
|
| 78 |
"""Base class for PDF-processing errors."""
|
| 79 |
|
|
|
|
| 88 |
|
| 89 |
# ββ LLM / Ollama βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 90 |
|
| 91 |
+
|
| 92 |
class LLMError(MediGuardError):
|
| 93 |
"""Base class for LLM-related errors."""
|
| 94 |
|
|
|
|
| 107 |
|
| 108 |
# ββ Biomarker domain βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 109 |
|
| 110 |
+
|
| 111 |
class BiomarkerError(MediGuardError):
|
| 112 |
"""Base class for biomarker-related errors."""
|
| 113 |
|
|
|
|
| 122 |
|
| 123 |
# ββ Medical analysis / workflow ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 124 |
|
| 125 |
+
|
| 126 |
class AnalysisError(MediGuardError):
|
| 127 |
"""The clinical-analysis workflow encountered an error."""
|
| 128 |
|
|
|
|
| 137 |
|
| 138 |
# ββ Cache ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 139 |
|
| 140 |
+
|
| 141 |
class CacheError(MediGuardError):
|
| 142 |
"""Base class for cache (Redis) errors."""
|
| 143 |
|
|
|
|
| 148 |
|
| 149 |
# ββ Observability ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 150 |
|
| 151 |
+
|
| 152 |
class ObservabilityError(MediGuardError):
|
| 153 |
"""Langfuse or metrics reporting failed (non-fatal)."""
|
| 154 |
|
| 155 |
|
| 156 |
# ββ Telegram bot βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 157 |
|
| 158 |
+
|
| 159 |
class TelegramError(MediGuardError):
|
| 160 |
"""Error from the Telegram bot integration."""
|
src/gradio_app.py
CHANGED
|
@@ -60,7 +60,7 @@ def _call_analyze(biomarkers_json: str) -> str:
|
|
| 60 |
summary = data.get("conversational_summary") or json.dumps(data, indent=2)
|
| 61 |
return summary
|
| 62 |
except json.JSONDecodeError:
|
| 63 |
-
return
|
| 64 |
except Exception as exc:
|
| 65 |
return f"Error: {exc}"
|
| 66 |
|
|
@@ -96,10 +96,12 @@ def launch_gradio(share: bool = False, server_port: int = 7860) -> None:
|
|
| 96 |
model_selector = gr.Dropdown(
|
| 97 |
choices=["llama-3.3-70b-versatile", "gemini-2.0-flash", "llama3.1:8b"],
|
| 98 |
value="llama-3.3-70b-versatile",
|
| 99 |
-
label="LLM Provider/Model"
|
| 100 |
)
|
| 101 |
|
| 102 |
-
ask_btn.click(
|
|
|
|
|
|
|
| 103 |
clear_btn.click(fn=lambda: ([], ""), outputs=[chatbot, question_input])
|
| 104 |
|
| 105 |
with gr.Tab("Analyze Biomarkers"):
|
|
@@ -115,16 +117,10 @@ def launch_gradio(share: bool = False, server_port: int = 7860) -> None:
|
|
| 115 |
with gr.Tab("Search Knowledge Base"):
|
| 116 |
with gr.Row():
|
| 117 |
search_input = gr.Textbox(
|
| 118 |
-
label="Search Query",
|
| 119 |
-
placeholder="e.g., diabetes management guidelines",
|
| 120 |
-
lines=2,
|
| 121 |
-
scale=3
|
| 122 |
)
|
| 123 |
search_mode = gr.Radio(
|
| 124 |
-
choices=["hybrid", "bm25", "vector"],
|
| 125 |
-
value="hybrid",
|
| 126 |
-
label="Search Strategy",
|
| 127 |
-
scale=1
|
| 128 |
)
|
| 129 |
search_btn = gr.Button("Search", variant="primary")
|
| 130 |
search_output = gr.Textbox(label="Results", lines=15, interactive=False)
|
|
|
|
| 60 |
summary = data.get("conversational_summary") or json.dumps(data, indent=2)
|
| 61 |
return summary
|
| 62 |
except json.JSONDecodeError:
|
| 63 |
+
return 'Invalid JSON. Please enter biomarkers as: {"Glucose": 185, "HbA1c": 8.2}'
|
| 64 |
except Exception as exc:
|
| 65 |
return f"Error: {exc}"
|
| 66 |
|
|
|
|
| 96 |
model_selector = gr.Dropdown(
|
| 97 |
choices=["llama-3.3-70b-versatile", "gemini-2.0-flash", "llama3.1:8b"],
|
| 98 |
value="llama-3.3-70b-versatile",
|
| 99 |
+
label="LLM Provider/Model",
|
| 100 |
)
|
| 101 |
|
| 102 |
+
ask_btn.click(
|
| 103 |
+
fn=ask_stream, inputs=[question_input, chatbot, model_selector], outputs=[question_input, chatbot]
|
| 104 |
+
)
|
| 105 |
clear_btn.click(fn=lambda: ([], ""), outputs=[chatbot, question_input])
|
| 106 |
|
| 107 |
with gr.Tab("Analyze Biomarkers"):
|
|
|
|
| 117 |
with gr.Tab("Search Knowledge Base"):
|
| 118 |
with gr.Row():
|
| 119 |
search_input = gr.Textbox(
|
| 120 |
+
label="Search Query", placeholder="e.g., diabetes management guidelines", lines=2, scale=3
|
|
|
|
|
|
|
|
|
|
| 121 |
)
|
| 122 |
search_mode = gr.Radio(
|
| 123 |
+
choices=["hybrid", "bm25", "vector"], value="hybrid", label="Search Strategy", scale=1
|
|
|
|
|
|
|
|
|
|
| 124 |
)
|
| 125 |
search_btn = gr.Button("Search", variant="primary")
|
| 126 |
search_output = gr.Textbox(label="Results", lines=15, interactive=False)
|
src/llm_config.py
CHANGED
|
@@ -32,7 +32,7 @@ def _get_env_with_fallback(primary: str, fallback: str, default: str = "") -> st
|
|
| 32 |
|
| 33 |
def get_default_llm_provider() -> str:
|
| 34 |
"""Get default LLM provider dynamically from environment.
|
| 35 |
-
|
| 36 |
Supports both naming conventions:
|
| 37 |
- LLM_PROVIDER (simple)
|
| 38 |
- LLM__PROVIDER (pydantic nested)
|
|
@@ -68,17 +68,17 @@ def get_chat_model(
|
|
| 68 |
provider: Literal["groq", "gemini", "ollama"] | None = None,
|
| 69 |
model: str | None = None,
|
| 70 |
temperature: float = 0.0,
|
| 71 |
-
json_mode: bool = False
|
| 72 |
):
|
| 73 |
"""
|
| 74 |
Get a chat model from the specified provider.
|
| 75 |
-
|
| 76 |
Args:
|
| 77 |
provider: "groq" (free, fast), "gemini" (free), or "ollama" (local)
|
| 78 |
model: Model name (provider-specific)
|
| 79 |
temperature: Sampling temperature
|
| 80 |
json_mode: Whether to enable JSON output mode
|
| 81 |
-
|
| 82 |
Returns:
|
| 83 |
LangChain chat model instance
|
| 84 |
"""
|
|
@@ -91,8 +91,7 @@ def get_chat_model(
|
|
| 91 |
api_key = get_groq_api_key()
|
| 92 |
if not api_key:
|
| 93 |
raise ValueError(
|
| 94 |
-
"GROQ_API_KEY not found in environment.\
|
| 95 |
-
"Get your FREE API key at: https://console.groq.com/keys"
|
| 96 |
)
|
| 97 |
|
| 98 |
# Use model from environment or default
|
|
@@ -102,7 +101,7 @@ def get_chat_model(
|
|
| 102 |
model=model,
|
| 103 |
temperature=temperature,
|
| 104 |
api_key=api_key,
|
| 105 |
-
model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {}
|
| 106 |
)
|
| 107 |
|
| 108 |
elif provider == "gemini":
|
|
@@ -119,10 +118,7 @@ def get_chat_model(
|
|
| 119 |
model = model or get_gemini_model()
|
| 120 |
|
| 121 |
return ChatGoogleGenerativeAI(
|
| 122 |
-
model=model,
|
| 123 |
-
temperature=temperature,
|
| 124 |
-
google_api_key=api_key,
|
| 125 |
-
convert_system_message_to_human=True
|
| 126 |
)
|
| 127 |
|
| 128 |
elif provider == "ollama":
|
|
@@ -133,11 +129,7 @@ def get_chat_model(
|
|
| 133 |
|
| 134 |
model = model or "llama3.1:8b"
|
| 135 |
|
| 136 |
-
return ChatOllama(
|
| 137 |
-
model=model,
|
| 138 |
-
temperature=temperature,
|
| 139 |
-
format='json' if json_mode else None
|
| 140 |
-
)
|
| 141 |
|
| 142 |
else:
|
| 143 |
raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
|
|
@@ -151,13 +143,13 @@ def get_embedding_provider() -> str:
|
|
| 151 |
def get_embedding_model(provider: Literal["jina", "google", "huggingface", "ollama"] | None = None):
|
| 152 |
"""
|
| 153 |
Get embedding model for vector search.
|
| 154 |
-
|
| 155 |
Args:
|
| 156 |
provider: "jina" (high-quality), "google" (free), "huggingface" (local), or "ollama" (local)
|
| 157 |
-
|
| 158 |
Returns:
|
| 159 |
LangChain embedding model instance
|
| 160 |
-
|
| 161 |
Note:
|
| 162 |
For production use, prefer src.services.embeddings.service.make_embedding_service()
|
| 163 |
which has automatic fallback chain: Jina β Google β HuggingFace.
|
|
@@ -171,6 +163,7 @@ def get_embedding_model(provider: Literal["jina", "google", "huggingface", "olla
|
|
| 171 |
try:
|
| 172 |
# Use the embedding service for Jina
|
| 173 |
from src.services.embeddings.service import make_embedding_service
|
|
|
|
| 174 |
return make_embedding_service()
|
| 175 |
except Exception as e:
|
| 176 |
print(f"WARN: Jina embeddings failed: {e}")
|
|
@@ -189,10 +182,7 @@ def get_embedding_model(provider: Literal["jina", "google", "huggingface", "olla
|
|
| 189 |
return get_embedding_model("huggingface")
|
| 190 |
|
| 191 |
try:
|
| 192 |
-
return GoogleGenerativeAIEmbeddings(
|
| 193 |
-
model="models/text-embedding-004",
|
| 194 |
-
google_api_key=api_key
|
| 195 |
-
)
|
| 196 |
except Exception as e:
|
| 197 |
print(f"WARN: Google embeddings failed: {e}")
|
| 198 |
print("INFO: Falling back to HuggingFace embeddings...")
|
|
@@ -204,9 +194,7 @@ def get_embedding_model(provider: Literal["jina", "google", "huggingface", "olla
|
|
| 204 |
except ImportError:
|
| 205 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 206 |
|
| 207 |
-
return HuggingFaceEmbeddings(
|
| 208 |
-
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
| 209 |
-
)
|
| 210 |
|
| 211 |
elif provider == "ollama":
|
| 212 |
try:
|
|
@@ -226,7 +214,7 @@ class LLMConfig:
|
|
| 226 |
def __init__(self, provider: str | None = None, lazy: bool = True):
|
| 227 |
"""
|
| 228 |
Initialize all model clients.
|
| 229 |
-
|
| 230 |
Args:
|
| 231 |
provider: LLM provider - "groq" (free), "gemini" (free), or "ollama" (local)
|
| 232 |
lazy: If True, defer model initialization until first use (avoids API key errors at import)
|
|
@@ -283,41 +271,21 @@ class LLMConfig:
|
|
| 283 |
print(f"Initializing LLM models with provider: {self.provider.upper()}")
|
| 284 |
|
| 285 |
# Fast model for structured tasks (planning, analysis)
|
| 286 |
-
self._planner = get_chat_model(
|
| 287 |
-
provider=self.provider,
|
| 288 |
-
temperature=0.0,
|
| 289 |
-
json_mode=True
|
| 290 |
-
)
|
| 291 |
|
| 292 |
# Fast model for biomarker analysis and quick tasks
|
| 293 |
-
self._analyzer = get_chat_model(
|
| 294 |
-
provider=self.provider,
|
| 295 |
-
temperature=0.0
|
| 296 |
-
)
|
| 297 |
|
| 298 |
# Medium model for RAG retrieval and explanation
|
| 299 |
-
self._explainer = get_chat_model(
|
| 300 |
-
provider=self.provider,
|
| 301 |
-
temperature=0.2
|
| 302 |
-
)
|
| 303 |
|
| 304 |
# Configurable synthesizers
|
| 305 |
-
self._synthesizer_7b = get_chat_model(
|
| 306 |
-
provider=self.provider,
|
| 307 |
-
temperature=0.2
|
| 308 |
-
)
|
| 309 |
|
| 310 |
-
self._synthesizer_8b = get_chat_model(
|
| 311 |
-
provider=self.provider,
|
| 312 |
-
temperature=0.2
|
| 313 |
-
)
|
| 314 |
|
| 315 |
# Director for Outer Loop
|
| 316 |
-
self._director = get_chat_model(
|
| 317 |
-
provider=self.provider,
|
| 318 |
-
temperature=0.0,
|
| 319 |
-
json_mode=True
|
| 320 |
-
)
|
| 321 |
|
| 322 |
# Embedding model for RAG
|
| 323 |
self._embedding_model = get_embedding_model()
|
|
|
|
| 32 |
|
| 33 |
def get_default_llm_provider() -> str:
|
| 34 |
"""Get default LLM provider dynamically from environment.
|
| 35 |
+
|
| 36 |
Supports both naming conventions:
|
| 37 |
- LLM_PROVIDER (simple)
|
| 38 |
- LLM__PROVIDER (pydantic nested)
|
|
|
|
| 68 |
provider: Literal["groq", "gemini", "ollama"] | None = None,
|
| 69 |
model: str | None = None,
|
| 70 |
temperature: float = 0.0,
|
| 71 |
+
json_mode: bool = False,
|
| 72 |
):
|
| 73 |
"""
|
| 74 |
Get a chat model from the specified provider.
|
| 75 |
+
|
| 76 |
Args:
|
| 77 |
provider: "groq" (free, fast), "gemini" (free), or "ollama" (local)
|
| 78 |
model: Model name (provider-specific)
|
| 79 |
temperature: Sampling temperature
|
| 80 |
json_mode: Whether to enable JSON output mode
|
| 81 |
+
|
| 82 |
Returns:
|
| 83 |
LangChain chat model instance
|
| 84 |
"""
|
|
|
|
| 91 |
api_key = get_groq_api_key()
|
| 92 |
if not api_key:
|
| 93 |
raise ValueError(
|
| 94 |
+
"GROQ_API_KEY not found in environment.\nGet your FREE API key at: https://console.groq.com/keys"
|
|
|
|
| 95 |
)
|
| 96 |
|
| 97 |
# Use model from environment or default
|
|
|
|
| 101 |
model=model,
|
| 102 |
temperature=temperature,
|
| 103 |
api_key=api_key,
|
| 104 |
+
model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {},
|
| 105 |
)
|
| 106 |
|
| 107 |
elif provider == "gemini":
|
|
|
|
| 118 |
model = model or get_gemini_model()
|
| 119 |
|
| 120 |
return ChatGoogleGenerativeAI(
|
| 121 |
+
model=model, temperature=temperature, google_api_key=api_key, convert_system_message_to_human=True
|
|
|
|
|
|
|
|
|
|
| 122 |
)
|
| 123 |
|
| 124 |
elif provider == "ollama":
|
|
|
|
| 129 |
|
| 130 |
model = model or "llama3.1:8b"
|
| 131 |
|
| 132 |
+
return ChatOllama(model=model, temperature=temperature, format="json" if json_mode else None)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
else:
|
| 135 |
raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
|
|
|
|
| 143 |
def get_embedding_model(provider: Literal["jina", "google", "huggingface", "ollama"] | None = None):
|
| 144 |
"""
|
| 145 |
Get embedding model for vector search.
|
| 146 |
+
|
| 147 |
Args:
|
| 148 |
provider: "jina" (high-quality), "google" (free), "huggingface" (local), or "ollama" (local)
|
| 149 |
+
|
| 150 |
Returns:
|
| 151 |
LangChain embedding model instance
|
| 152 |
+
|
| 153 |
Note:
|
| 154 |
For production use, prefer src.services.embeddings.service.make_embedding_service()
|
| 155 |
which has automatic fallback chain: Jina β Google β HuggingFace.
|
|
|
|
| 163 |
try:
|
| 164 |
# Use the embedding service for Jina
|
| 165 |
from src.services.embeddings.service import make_embedding_service
|
| 166 |
+
|
| 167 |
return make_embedding_service()
|
| 168 |
except Exception as e:
|
| 169 |
print(f"WARN: Jina embeddings failed: {e}")
|
|
|
|
| 182 |
return get_embedding_model("huggingface")
|
| 183 |
|
| 184 |
try:
|
| 185 |
+
return GoogleGenerativeAIEmbeddings(model="models/text-embedding-004", google_api_key=api_key)
|
|
|
|
|
|
|
|
|
|
| 186 |
except Exception as e:
|
| 187 |
print(f"WARN: Google embeddings failed: {e}")
|
| 188 |
print("INFO: Falling back to HuggingFace embeddings...")
|
|
|
|
| 194 |
except ImportError:
|
| 195 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 196 |
|
| 197 |
+
return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
|
|
|
|
|
|
| 198 |
|
| 199 |
elif provider == "ollama":
|
| 200 |
try:
|
|
|
|
| 214 |
def __init__(self, provider: str | None = None, lazy: bool = True):
|
| 215 |
"""
|
| 216 |
Initialize all model clients.
|
| 217 |
+
|
| 218 |
Args:
|
| 219 |
provider: LLM provider - "groq" (free), "gemini" (free), or "ollama" (local)
|
| 220 |
lazy: If True, defer model initialization until first use (avoids API key errors at import)
|
|
|
|
| 271 |
print(f"Initializing LLM models with provider: {self.provider.upper()}")
|
| 272 |
|
| 273 |
# Fast model for structured tasks (planning, analysis)
|
| 274 |
+
self._planner = get_chat_model(provider=self.provider, temperature=0.0, json_mode=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
# Fast model for biomarker analysis and quick tasks
|
| 277 |
+
self._analyzer = get_chat_model(provider=self.provider, temperature=0.0)
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
# Medium model for RAG retrieval and explanation
|
| 280 |
+
self._explainer = get_chat_model(provider=self.provider, temperature=0.2)
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
# Configurable synthesizers
|
| 283 |
+
self._synthesizer_7b = get_chat_model(provider=self.provider, temperature=0.2)
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
+
self._synthesizer_8b = get_chat_model(provider=self.provider, temperature=0.2)
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
# Director for Outer Loop
|
| 288 |
+
self._director = get_chat_model(provider=self.provider, temperature=0.0, json_mode=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
# Embedding model for RAG
|
| 291 |
self._embedding_model = get_embedding_model()
|
src/main.py
CHANGED
|
@@ -35,6 +35,7 @@ logger = logging.getLogger("mediguard")
|
|
| 35 |
# Lifespan
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
|
|
|
|
| 38 |
@asynccontextmanager
|
| 39 |
async def lifespan(app: FastAPI):
|
| 40 |
"""Initialise production services on startup, tear them down on shutdown."""
|
|
@@ -50,6 +51,7 @@ async def lifespan(app: FastAPI):
|
|
| 50 |
try:
|
| 51 |
from src.services.opensearch.client import make_opensearch_client
|
| 52 |
from src.services.opensearch.index_config import MEDICAL_CHUNKS_MAPPING
|
|
|
|
| 53 |
app.state.opensearch_client = make_opensearch_client()
|
| 54 |
app.state.opensearch_client.ensure_index(MEDICAL_CHUNKS_MAPPING)
|
| 55 |
logger.info("OpenSearch client ready")
|
|
@@ -60,6 +62,7 @@ async def lifespan(app: FastAPI):
|
|
| 60 |
# --- Embedding service ---
|
| 61 |
try:
|
| 62 |
from src.services.embeddings.service import make_embedding_service
|
|
|
|
| 63 |
app.state.embedding_service = make_embedding_service()
|
| 64 |
logger.info("Embedding service ready (provider=%s)", app.state.embedding_service.provider_name)
|
| 65 |
except Exception as exc:
|
|
@@ -69,6 +72,7 @@ async def lifespan(app: FastAPI):
|
|
| 69 |
# --- Redis cache ---
|
| 70 |
try:
|
| 71 |
from src.services.cache.redis_cache import make_redis_cache
|
|
|
|
| 72 |
app.state.cache = make_redis_cache()
|
| 73 |
logger.info("Redis cache ready")
|
| 74 |
except Exception as exc:
|
|
@@ -78,6 +82,7 @@ async def lifespan(app: FastAPI):
|
|
| 78 |
# --- Ollama LLM ---
|
| 79 |
try:
|
| 80 |
from src.services.ollama.client import make_ollama_client
|
|
|
|
| 81 |
app.state.ollama_client = make_ollama_client()
|
| 82 |
logger.info("Ollama client ready")
|
| 83 |
except Exception as exc:
|
|
@@ -87,6 +92,7 @@ async def lifespan(app: FastAPI):
|
|
| 87 |
# --- Langfuse tracer ---
|
| 88 |
try:
|
| 89 |
from src.services.langfuse.tracer import make_langfuse_tracer
|
|
|
|
| 90 |
app.state.tracer = make_langfuse_tracer()
|
| 91 |
logger.info("Langfuse tracer ready")
|
| 92 |
except Exception as exc:
|
|
@@ -98,6 +104,7 @@ async def lifespan(app: FastAPI):
|
|
| 98 |
from src.llm_config import get_llm
|
| 99 |
from src.services.agents.agentic_rag import AgenticRAGService
|
| 100 |
from src.services.agents.context import AgenticContext
|
|
|
|
| 101 |
if app.state.opensearch_client and app.state.embedding_service:
|
| 102 |
llm = get_llm()
|
| 103 |
ctx = AgenticContext(
|
|
@@ -119,6 +126,7 @@ async def lifespan(app: FastAPI):
|
|
| 119 |
# --- Legacy RagBot service (backward-compatible /analyze) ---
|
| 120 |
try:
|
| 121 |
from src.workflow import create_guild
|
|
|
|
| 122 |
guild = create_guild()
|
| 123 |
app.state.ragbot_service = guild
|
| 124 |
logger.info("RagBot service ready (ClinicalInsightGuild)")
|
|
@@ -130,6 +138,7 @@ async def lifespan(app: FastAPI):
|
|
| 130 |
try:
|
| 131 |
from src.llm_config import get_llm
|
| 132 |
from src.services.extraction.service import make_extraction_service
|
|
|
|
| 133 |
try:
|
| 134 |
llm = get_llm()
|
| 135 |
except Exception as e:
|
|
@@ -154,6 +163,7 @@ async def lifespan(app: FastAPI):
|
|
| 154 |
# App factory
|
| 155 |
# ---------------------------------------------------------------------------
|
| 156 |
|
|
|
|
| 157 |
def create_app() -> FastAPI:
|
| 158 |
"""Build and return the configured FastAPI application."""
|
| 159 |
settings = get_settings()
|
|
@@ -180,6 +190,7 @@ def create_app() -> FastAPI:
|
|
| 180 |
|
| 181 |
# --- Security & HIPAA Compliance ---
|
| 182 |
from src.middlewares import HIPAAAuditMiddleware, SecurityHeadersMiddleware
|
|
|
|
| 183 |
app.add_middleware(SecurityHeadersMiddleware)
|
| 184 |
app.add_middleware(HIPAAAuditMiddleware)
|
| 185 |
|
|
|
|
| 35 |
# Lifespan
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
|
| 38 |
+
|
| 39 |
@asynccontextmanager
|
| 40 |
async def lifespan(app: FastAPI):
|
| 41 |
"""Initialise production services on startup, tear them down on shutdown."""
|
|
|
|
| 51 |
try:
|
| 52 |
from src.services.opensearch.client import make_opensearch_client
|
| 53 |
from src.services.opensearch.index_config import MEDICAL_CHUNKS_MAPPING
|
| 54 |
+
|
| 55 |
app.state.opensearch_client = make_opensearch_client()
|
| 56 |
app.state.opensearch_client.ensure_index(MEDICAL_CHUNKS_MAPPING)
|
| 57 |
logger.info("OpenSearch client ready")
|
|
|
|
| 62 |
# --- Embedding service ---
|
| 63 |
try:
|
| 64 |
from src.services.embeddings.service import make_embedding_service
|
| 65 |
+
|
| 66 |
app.state.embedding_service = make_embedding_service()
|
| 67 |
logger.info("Embedding service ready (provider=%s)", app.state.embedding_service.provider_name)
|
| 68 |
except Exception as exc:
|
|
|
|
| 72 |
# --- Redis cache ---
|
| 73 |
try:
|
| 74 |
from src.services.cache.redis_cache import make_redis_cache
|
| 75 |
+
|
| 76 |
app.state.cache = make_redis_cache()
|
| 77 |
logger.info("Redis cache ready")
|
| 78 |
except Exception as exc:
|
|
|
|
| 82 |
# --- Ollama LLM ---
|
| 83 |
try:
|
| 84 |
from src.services.ollama.client import make_ollama_client
|
| 85 |
+
|
| 86 |
app.state.ollama_client = make_ollama_client()
|
| 87 |
logger.info("Ollama client ready")
|
| 88 |
except Exception as exc:
|
|
|
|
| 92 |
# --- Langfuse tracer ---
|
| 93 |
try:
|
| 94 |
from src.services.langfuse.tracer import make_langfuse_tracer
|
| 95 |
+
|
| 96 |
app.state.tracer = make_langfuse_tracer()
|
| 97 |
logger.info("Langfuse tracer ready")
|
| 98 |
except Exception as exc:
|
|
|
|
| 104 |
from src.llm_config import get_llm
|
| 105 |
from src.services.agents.agentic_rag import AgenticRAGService
|
| 106 |
from src.services.agents.context import AgenticContext
|
| 107 |
+
|
| 108 |
if app.state.opensearch_client and app.state.embedding_service:
|
| 109 |
llm = get_llm()
|
| 110 |
ctx = AgenticContext(
|
|
|
|
| 126 |
# --- Legacy RagBot service (backward-compatible /analyze) ---
|
| 127 |
try:
|
| 128 |
from src.workflow import create_guild
|
| 129 |
+
|
| 130 |
guild = create_guild()
|
| 131 |
app.state.ragbot_service = guild
|
| 132 |
logger.info("RagBot service ready (ClinicalInsightGuild)")
|
|
|
|
| 138 |
try:
|
| 139 |
from src.llm_config import get_llm
|
| 140 |
from src.services.extraction.service import make_extraction_service
|
| 141 |
+
|
| 142 |
try:
|
| 143 |
llm = get_llm()
|
| 144 |
except Exception as e:
|
|
|
|
| 163 |
# App factory
|
| 164 |
# ---------------------------------------------------------------------------
|
| 165 |
|
| 166 |
+
|
| 167 |
def create_app() -> FastAPI:
|
| 168 |
"""Build and return the configured FastAPI application."""
|
| 169 |
settings = get_settings()
|
|
|
|
| 190 |
|
| 191 |
# --- Security & HIPAA Compliance ---
|
| 192 |
from src.middlewares import HIPAAAuditMiddleware, SecurityHeadersMiddleware
|
| 193 |
+
|
| 194 |
app.add_middleware(SecurityHeadersMiddleware)
|
| 195 |
app.add_middleware(HIPAAAuditMiddleware)
|
| 196 |
|
src/middlewares.py
CHANGED
|
@@ -27,8 +27,20 @@ logger = logging.getLogger("mediguard.audit")
|
|
| 27 |
|
| 28 |
# Sensitive fields that should NEVER be logged
|
| 29 |
SENSITIVE_FIELDS = {
|
| 30 |
-
"biomarkers",
|
| 31 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
}
|
| 33 |
|
| 34 |
# Endpoints that require audit logging
|
|
@@ -65,14 +77,14 @@ def _redact_body(body_dict: dict) -> dict:
|
|
| 65 |
class HIPAAAuditMiddleware(BaseHTTPMiddleware):
|
| 66 |
"""
|
| 67 |
HIPAA-compliant audit logging middleware.
|
| 68 |
-
|
| 69 |
Features:
|
| 70 |
- Generates unique request IDs for traceability
|
| 71 |
- Logs request metadata WITHOUT PHI/biomarker values
|
| 72 |
- Creates audit trail for all medical analysis requests
|
| 73 |
- Tracks request timing and response status
|
| 74 |
- Hashes sensitive identifiers for correlation
|
| 75 |
-
|
| 76 |
Audit logs are structured JSON for easy SIEM integration.
|
| 77 |
"""
|
| 78 |
|
|
@@ -116,7 +128,9 @@ class HIPAAAuditMiddleware(BaseHTTPMiddleware):
|
|
| 116 |
audit_entry["request_fields"] = list(redacted.keys())
|
| 117 |
# Log presence of biomarkers without values
|
| 118 |
if "biomarkers" in body_dict:
|
| 119 |
-
audit_entry["biomarker_count"] =
|
|
|
|
|
|
|
| 120 |
except Exception as exc:
|
| 121 |
logger.debug("Failed to audit POST body: %s", exc)
|
| 122 |
|
|
|
|
| 27 |
|
| 28 |
# Sensitive fields that should NEVER be logged
|
| 29 |
SENSITIVE_FIELDS = {
|
| 30 |
+
"biomarkers",
|
| 31 |
+
"patient_context",
|
| 32 |
+
"patient_id",
|
| 33 |
+
"age",
|
| 34 |
+
"gender",
|
| 35 |
+
"bmi",
|
| 36 |
+
"ssn",
|
| 37 |
+
"mrn",
|
| 38 |
+
"name",
|
| 39 |
+
"address",
|
| 40 |
+
"phone",
|
| 41 |
+
"email",
|
| 42 |
+
"dob",
|
| 43 |
+
"date_of_birth",
|
| 44 |
}
|
| 45 |
|
| 46 |
# Endpoints that require audit logging
|
|
|
|
| 77 |
class HIPAAAuditMiddleware(BaseHTTPMiddleware):
|
| 78 |
"""
|
| 79 |
HIPAA-compliant audit logging middleware.
|
| 80 |
+
|
| 81 |
Features:
|
| 82 |
- Generates unique request IDs for traceability
|
| 83 |
- Logs request metadata WITHOUT PHI/biomarker values
|
| 84 |
- Creates audit trail for all medical analysis requests
|
| 85 |
- Tracks request timing and response status
|
| 86 |
- Hashes sensitive identifiers for correlation
|
| 87 |
+
|
| 88 |
Audit logs are structured JSON for easy SIEM integration.
|
| 89 |
"""
|
| 90 |
|
|
|
|
| 128 |
audit_entry["request_fields"] = list(redacted.keys())
|
| 129 |
# Log presence of biomarkers without values
|
| 130 |
if "biomarkers" in body_dict:
|
| 131 |
+
audit_entry["biomarker_count"] = (
|
| 132 |
+
len(body_dict["biomarkers"]) if isinstance(body_dict["biomarkers"], dict) else 1
|
| 133 |
+
)
|
| 134 |
except Exception as exc:
|
| 135 |
logger.debug("Failed to audit POST body: %s", exc)
|
| 136 |
|
src/pdf_processor.py
CHANGED
|
@@ -32,11 +32,11 @@ class PDFProcessor:
|
|
| 32 |
pdf_directory: str = "data/medical_pdfs",
|
| 33 |
vector_store_path: str = "data/vector_stores",
|
| 34 |
chunk_size: int = 1000,
|
| 35 |
-
chunk_overlap: int = 200
|
| 36 |
):
|
| 37 |
"""
|
| 38 |
Initialize PDF processor.
|
| 39 |
-
|
| 40 |
Args:
|
| 41 |
pdf_directory: Path to folder containing medical PDFs
|
| 42 |
vector_store_path: Path to save FAISS vector stores
|
|
@@ -57,13 +57,13 @@ class PDFProcessor:
|
|
| 57 |
chunk_size=chunk_size,
|
| 58 |
chunk_overlap=chunk_overlap,
|
| 59 |
separators=["\n\n", "\n", ". ", " ", ""],
|
| 60 |
-
length_function=len
|
| 61 |
)
|
| 62 |
|
| 63 |
def load_pdfs(self) -> list[Document]:
|
| 64 |
"""
|
| 65 |
Load all PDF documents from the configured directory.
|
| 66 |
-
|
| 67 |
Returns:
|
| 68 |
List of Document objects with content and metadata
|
| 69 |
"""
|
|
@@ -89,8 +89,8 @@ class PDFProcessor:
|
|
| 89 |
|
| 90 |
# Add source filename to metadata
|
| 91 |
for doc in docs:
|
| 92 |
-
doc.metadata[
|
| 93 |
-
doc.metadata[
|
| 94 |
|
| 95 |
documents.extend(docs)
|
| 96 |
print(f" OK: Loaded {len(docs)} pages from {pdf_path.name}")
|
|
@@ -104,10 +104,10 @@ class PDFProcessor:
|
|
| 104 |
def chunk_documents(self, documents: list[Document]) -> list[Document]:
|
| 105 |
"""
|
| 106 |
Split documents into chunks for RAG retrieval.
|
| 107 |
-
|
| 108 |
Args:
|
| 109 |
documents: List of loaded documents
|
| 110 |
-
|
| 111 |
Returns:
|
| 112 |
List of chunked documents with preserved metadata
|
| 113 |
"""
|
|
@@ -121,7 +121,7 @@ class PDFProcessor:
|
|
| 121 |
|
| 122 |
# Add chunk index to metadata
|
| 123 |
for i, chunk in enumerate(chunks):
|
| 124 |
-
chunk.metadata[
|
| 125 |
|
| 126 |
print(f"OK: Created {len(chunks)} chunks from {len(documents)} pages")
|
| 127 |
print(f" Average chunk size: {sum(len(c.page_content) for c in chunks) // len(chunks)} characters")
|
|
@@ -129,19 +129,16 @@ class PDFProcessor:
|
|
| 129 |
return chunks
|
| 130 |
|
| 131 |
def create_vector_store(
|
| 132 |
-
self,
|
| 133 |
-
chunks: list[Document],
|
| 134 |
-
embedding_model,
|
| 135 |
-
store_name: str = "medical_knowledge"
|
| 136 |
) -> FAISS:
|
| 137 |
"""
|
| 138 |
Create FAISS vector store from document chunks.
|
| 139 |
-
|
| 140 |
Args:
|
| 141 |
chunks: Document chunks to embed
|
| 142 |
embedding_model: Embedding model (from llm_config)
|
| 143 |
store_name: Name for the vector store
|
| 144 |
-
|
| 145 |
Returns:
|
| 146 |
FAISS vector store object
|
| 147 |
"""
|
|
@@ -150,10 +147,7 @@ class PDFProcessor:
|
|
| 150 |
print("(This may take a few minutes...)")
|
| 151 |
|
| 152 |
# Create FAISS vector store
|
| 153 |
-
vector_store = FAISS.from_documents(
|
| 154 |
-
documents=chunks,
|
| 155 |
-
embedding=embedding_model
|
| 156 |
-
)
|
| 157 |
|
| 158 |
# Save to disk
|
| 159 |
save_path = self.vector_store_path / f"{store_name}.faiss"
|
|
@@ -163,18 +157,14 @@ class PDFProcessor:
|
|
| 163 |
|
| 164 |
return vector_store
|
| 165 |
|
| 166 |
-
def load_vector_store(
|
| 167 |
-
self,
|
| 168 |
-
embedding_model,
|
| 169 |
-
store_name: str = "medical_knowledge"
|
| 170 |
-
) -> FAISS | None:
|
| 171 |
"""
|
| 172 |
Load existing vector store from disk.
|
| 173 |
-
|
| 174 |
Args:
|
| 175 |
embedding_model: Embedding model (must match the one used to create store)
|
| 176 |
store_name: Name of the vector store
|
| 177 |
-
|
| 178 |
Returns:
|
| 179 |
FAISS vector store or None if not found
|
| 180 |
"""
|
|
@@ -192,7 +182,7 @@ class PDFProcessor:
|
|
| 192 |
str(self.vector_store_path),
|
| 193 |
embedding_model,
|
| 194 |
index_name=store_name,
|
| 195 |
-
allow_dangerous_deserialization=True
|
| 196 |
)
|
| 197 |
print(f"OK: Loaded vector store from: {store_path}")
|
| 198 |
return vector_store
|
|
@@ -202,19 +192,16 @@ class PDFProcessor:
|
|
| 202 |
return None
|
| 203 |
|
| 204 |
def create_retrievers(
|
| 205 |
-
self,
|
| 206 |
-
embedding_model,
|
| 207 |
-
store_name: str = "medical_knowledge",
|
| 208 |
-
force_rebuild: bool = False
|
| 209 |
) -> dict:
|
| 210 |
"""
|
| 211 |
Create or load retrievers for RAG.
|
| 212 |
-
|
| 213 |
Args:
|
| 214 |
embedding_model: Embedding model
|
| 215 |
store_name: Vector store name
|
| 216 |
force_rebuild: If True, rebuild vector store even if it exists
|
| 217 |
-
|
| 218 |
Returns:
|
| 219 |
Dictionary of retrievers for different purposes
|
| 220 |
"""
|
|
@@ -238,18 +225,10 @@ class PDFProcessor:
|
|
| 238 |
|
| 239 |
# Create specialized retrievers
|
| 240 |
retrievers = {
|
| 241 |
-
"disease_explainer": vector_store.as_retriever(
|
| 242 |
-
|
| 243 |
-
),
|
| 244 |
-
"
|
| 245 |
-
search_kwargs={"k": 3}
|
| 246 |
-
),
|
| 247 |
-
"clinical_guidelines": vector_store.as_retriever(
|
| 248 |
-
search_kwargs={"k": 3}
|
| 249 |
-
),
|
| 250 |
-
"general": vector_store.as_retriever(
|
| 251 |
-
search_kwargs={"k": 5}
|
| 252 |
-
)
|
| 253 |
}
|
| 254 |
|
| 255 |
print(f"\nOK: Created {len(retrievers)} specialized retrievers")
|
|
@@ -259,12 +238,12 @@ class PDFProcessor:
|
|
| 259 |
def setup_knowledge_base(embedding_model=None, force_rebuild: bool = False, use_configured_embeddings: bool = True):
|
| 260 |
"""
|
| 261 |
Convenience function to set up the complete knowledge base.
|
| 262 |
-
|
| 263 |
Args:
|
| 264 |
embedding_model: Embedding model (optional if use_configured_embeddings=True)
|
| 265 |
force_rebuild: Force rebuild of vector stores
|
| 266 |
use_configured_embeddings: Use embedding provider from EMBEDDING_PROVIDER env var
|
| 267 |
-
|
| 268 |
Returns:
|
| 269 |
Dictionary of retrievers ready for use
|
| 270 |
"""
|
|
@@ -281,9 +260,7 @@ def setup_knowledge_base(embedding_model=None, force_rebuild: bool = False, use_
|
|
| 281 |
|
| 282 |
processor = PDFProcessor()
|
| 283 |
retrievers = processor.create_retrievers(
|
| 284 |
-
embedding_model,
|
| 285 |
-
store_name="medical_knowledge",
|
| 286 |
-
force_rebuild=force_rebuild
|
| 287 |
)
|
| 288 |
|
| 289 |
if retrievers:
|
|
@@ -300,19 +277,16 @@ def get_all_retrievers(force_rebuild: bool = False) -> dict:
|
|
| 300 |
"""
|
| 301 |
Quick function to get all retrievers using configured embedding provider.
|
| 302 |
Used by workflow.py to initialize the Clinical Insight Guild.
|
| 303 |
-
|
| 304 |
Uses EMBEDDING_PROVIDER from .env: "google" (default), "huggingface", or "ollama"
|
| 305 |
-
|
| 306 |
Args:
|
| 307 |
force_rebuild: Force rebuild of vector stores
|
| 308 |
-
|
| 309 |
Returns:
|
| 310 |
Dictionary of retrievers for all agent types
|
| 311 |
"""
|
| 312 |
-
return setup_knowledge_base(
|
| 313 |
-
use_configured_embeddings=True,
|
| 314 |
-
force_rebuild=force_rebuild
|
| 315 |
-
)
|
| 316 |
|
| 317 |
|
| 318 |
if __name__ == "__main__":
|
|
@@ -323,16 +297,16 @@ if __name__ == "__main__":
|
|
| 323 |
# Add parent directory to path for imports
|
| 324 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 325 |
|
| 326 |
-
print("\n" + "="*70)
|
| 327 |
print("MediGuard AI - PDF Knowledge Base Builder")
|
| 328 |
-
print("="*70)
|
| 329 |
print("\nUsing configured embedding provider from .env")
|
| 330 |
print(" EMBEDDING_PROVIDER options: google (default), huggingface, ollama")
|
| 331 |
-
print("="*70)
|
| 332 |
|
| 333 |
retrievers = setup_knowledge_base(
|
| 334 |
use_configured_embeddings=True, # Use configured provider
|
| 335 |
-
force_rebuild=False
|
| 336 |
)
|
| 337 |
|
| 338 |
if retrievers:
|
|
|
|
| 32 |
pdf_directory: str = "data/medical_pdfs",
|
| 33 |
vector_store_path: str = "data/vector_stores",
|
| 34 |
chunk_size: int = 1000,
|
| 35 |
+
chunk_overlap: int = 200,
|
| 36 |
):
|
| 37 |
"""
|
| 38 |
Initialize PDF processor.
|
| 39 |
+
|
| 40 |
Args:
|
| 41 |
pdf_directory: Path to folder containing medical PDFs
|
| 42 |
vector_store_path: Path to save FAISS vector stores
|
|
|
|
| 57 |
chunk_size=chunk_size,
|
| 58 |
chunk_overlap=chunk_overlap,
|
| 59 |
separators=["\n\n", "\n", ". ", " ", ""],
|
| 60 |
+
length_function=len,
|
| 61 |
)
|
| 62 |
|
| 63 |
def load_pdfs(self) -> list[Document]:
|
| 64 |
"""
|
| 65 |
Load all PDF documents from the configured directory.
|
| 66 |
+
|
| 67 |
Returns:
|
| 68 |
List of Document objects with content and metadata
|
| 69 |
"""
|
|
|
|
| 89 |
|
| 90 |
# Add source filename to metadata
|
| 91 |
for doc in docs:
|
| 92 |
+
doc.metadata["source_file"] = pdf_path.name
|
| 93 |
+
doc.metadata["source_path"] = str(pdf_path)
|
| 94 |
|
| 95 |
documents.extend(docs)
|
| 96 |
print(f" OK: Loaded {len(docs)} pages from {pdf_path.name}")
|
|
|
|
| 104 |
def chunk_documents(self, documents: list[Document]) -> list[Document]:
|
| 105 |
"""
|
| 106 |
Split documents into chunks for RAG retrieval.
|
| 107 |
+
|
| 108 |
Args:
|
| 109 |
documents: List of loaded documents
|
| 110 |
+
|
| 111 |
Returns:
|
| 112 |
List of chunked documents with preserved metadata
|
| 113 |
"""
|
|
|
|
| 121 |
|
| 122 |
# Add chunk index to metadata
|
| 123 |
for i, chunk in enumerate(chunks):
|
| 124 |
+
chunk.metadata["chunk_id"] = i
|
| 125 |
|
| 126 |
print(f"OK: Created {len(chunks)} chunks from {len(documents)} pages")
|
| 127 |
print(f" Average chunk size: {sum(len(c.page_content) for c in chunks) // len(chunks)} characters")
|
|
|
|
| 129 |
return chunks
|
| 130 |
|
| 131 |
def create_vector_store(
|
| 132 |
+
self, chunks: list[Document], embedding_model, store_name: str = "medical_knowledge"
|
|
|
|
|
|
|
|
|
|
| 133 |
) -> FAISS:
|
| 134 |
"""
|
| 135 |
Create FAISS vector store from document chunks.
|
| 136 |
+
|
| 137 |
Args:
|
| 138 |
chunks: Document chunks to embed
|
| 139 |
embedding_model: Embedding model (from llm_config)
|
| 140 |
store_name: Name for the vector store
|
| 141 |
+
|
| 142 |
Returns:
|
| 143 |
FAISS vector store object
|
| 144 |
"""
|
|
|
|
| 147 |
print("(This may take a few minutes...)")
|
| 148 |
|
| 149 |
# Create FAISS vector store
|
| 150 |
+
vector_store = FAISS.from_documents(documents=chunks, embedding=embedding_model)
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
# Save to disk
|
| 153 |
save_path = self.vector_store_path / f"{store_name}.faiss"
|
|
|
|
| 157 |
|
| 158 |
return vector_store
|
| 159 |
|
| 160 |
+
def load_vector_store(self, embedding_model, store_name: str = "medical_knowledge") -> FAISS | None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
"""
|
| 162 |
Load existing vector store from disk.
|
| 163 |
+
|
| 164 |
Args:
|
| 165 |
embedding_model: Embedding model (must match the one used to create store)
|
| 166 |
store_name: Name of the vector store
|
| 167 |
+
|
| 168 |
Returns:
|
| 169 |
FAISS vector store or None if not found
|
| 170 |
"""
|
|
|
|
| 182 |
str(self.vector_store_path),
|
| 183 |
embedding_model,
|
| 184 |
index_name=store_name,
|
| 185 |
+
allow_dangerous_deserialization=True,
|
| 186 |
)
|
| 187 |
print(f"OK: Loaded vector store from: {store_path}")
|
| 188 |
return vector_store
|
|
|
|
| 192 |
return None
|
| 193 |
|
| 194 |
def create_retrievers(
|
| 195 |
+
self, embedding_model, store_name: str = "medical_knowledge", force_rebuild: bool = False
|
|
|
|
|
|
|
|
|
|
| 196 |
) -> dict:
|
| 197 |
"""
|
| 198 |
Create or load retrievers for RAG.
|
| 199 |
+
|
| 200 |
Args:
|
| 201 |
embedding_model: Embedding model
|
| 202 |
store_name: Vector store name
|
| 203 |
force_rebuild: If True, rebuild vector store even if it exists
|
| 204 |
+
|
| 205 |
Returns:
|
| 206 |
Dictionary of retrievers for different purposes
|
| 207 |
"""
|
|
|
|
| 225 |
|
| 226 |
# Create specialized retrievers
|
| 227 |
retrievers = {
|
| 228 |
+
"disease_explainer": vector_store.as_retriever(search_kwargs={"k": 5}),
|
| 229 |
+
"biomarker_linker": vector_store.as_retriever(search_kwargs={"k": 3}),
|
| 230 |
+
"clinical_guidelines": vector_store.as_retriever(search_kwargs={"k": 3}),
|
| 231 |
+
"general": vector_store.as_retriever(search_kwargs={"k": 5}),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
}
|
| 233 |
|
| 234 |
print(f"\nOK: Created {len(retrievers)} specialized retrievers")
|
|
|
|
| 238 |
def setup_knowledge_base(embedding_model=None, force_rebuild: bool = False, use_configured_embeddings: bool = True):
|
| 239 |
"""
|
| 240 |
Convenience function to set up the complete knowledge base.
|
| 241 |
+
|
| 242 |
Args:
|
| 243 |
embedding_model: Embedding model (optional if use_configured_embeddings=True)
|
| 244 |
force_rebuild: Force rebuild of vector stores
|
| 245 |
use_configured_embeddings: Use embedding provider from EMBEDDING_PROVIDER env var
|
| 246 |
+
|
| 247 |
Returns:
|
| 248 |
Dictionary of retrievers ready for use
|
| 249 |
"""
|
|
|
|
| 260 |
|
| 261 |
processor = PDFProcessor()
|
| 262 |
retrievers = processor.create_retrievers(
|
| 263 |
+
embedding_model, store_name="medical_knowledge", force_rebuild=force_rebuild
|
|
|
|
|
|
|
| 264 |
)
|
| 265 |
|
| 266 |
if retrievers:
|
|
|
|
| 277 |
"""
|
| 278 |
Quick function to get all retrievers using configured embedding provider.
|
| 279 |
Used by workflow.py to initialize the Clinical Insight Guild.
|
| 280 |
+
|
| 281 |
Uses EMBEDDING_PROVIDER from .env: "google" (default), "huggingface", or "ollama"
|
| 282 |
+
|
| 283 |
Args:
|
| 284 |
force_rebuild: Force rebuild of vector stores
|
| 285 |
+
|
| 286 |
Returns:
|
| 287 |
Dictionary of retrievers for all agent types
|
| 288 |
"""
|
| 289 |
+
return setup_knowledge_base(use_configured_embeddings=True, force_rebuild=force_rebuild)
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
|
| 292 |
if __name__ == "__main__":
|
|
|
|
| 297 |
# Add parent directory to path for imports
|
| 298 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 299 |
|
| 300 |
+
print("\n" + "=" * 70)
|
| 301 |
print("MediGuard AI - PDF Knowledge Base Builder")
|
| 302 |
+
print("=" * 70)
|
| 303 |
print("\nUsing configured embedding provider from .env")
|
| 304 |
print(" EMBEDDING_PROVIDER options: google (default), huggingface, ollama")
|
| 305 |
+
print("=" * 70)
|
| 306 |
|
| 307 |
retrievers = setup_knowledge_base(
|
| 308 |
use_configured_embeddings=True, # Use configured provider
|
| 309 |
+
force_rebuild=False,
|
| 310 |
)
|
| 311 |
|
| 312 |
if retrievers:
|
src/repositories/analysis.py
CHANGED
|
@@ -21,19 +21,10 @@ class AnalysisRepository:
|
|
| 21 |
return analysis
|
| 22 |
|
| 23 |
def get_by_request_id(self, request_id: str) -> PatientAnalysis | None:
|
| 24 |
-
return (
|
| 25 |
-
self.db.query(PatientAnalysis)
|
| 26 |
-
.filter(PatientAnalysis.request_id == request_id)
|
| 27 |
-
.first()
|
| 28 |
-
)
|
| 29 |
|
| 30 |
def list_recent(self, limit: int = 20) -> list[PatientAnalysis]:
|
| 31 |
-
return (
|
| 32 |
-
self.db.query(PatientAnalysis)
|
| 33 |
-
.order_by(PatientAnalysis.created_at.desc())
|
| 34 |
-
.limit(limit)
|
| 35 |
-
.all()
|
| 36 |
-
)
|
| 37 |
|
| 38 |
def count(self) -> int:
|
| 39 |
return self.db.query(PatientAnalysis).count()
|
|
|
|
| 21 |
return analysis
|
| 22 |
|
| 23 |
def get_by_request_id(self, request_id: str) -> PatientAnalysis | None:
|
| 24 |
+
return self.db.query(PatientAnalysis).filter(PatientAnalysis.request_id == request_id).first()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def list_recent(self, limit: int = 20) -> list[PatientAnalysis]:
|
| 27 |
+
return self.db.query(PatientAnalysis).order_by(PatientAnalysis.created_at.desc()).limit(limit).all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def count(self) -> int:
|
| 30 |
return self.db.query(PatientAnalysis).count()
|
src/repositories/document.py
CHANGED
|
@@ -16,11 +16,7 @@ class DocumentRepository:
|
|
| 16 |
self.db = db
|
| 17 |
|
| 18 |
def upsert(self, doc: MedicalDocument) -> MedicalDocument:
|
| 19 |
-
existing = (
|
| 20 |
-
self.db.query(MedicalDocument)
|
| 21 |
-
.filter(MedicalDocument.content_hash == doc.content_hash)
|
| 22 |
-
.first()
|
| 23 |
-
)
|
| 24 |
if existing:
|
| 25 |
existing.parse_status = doc.parse_status
|
| 26 |
existing.chunk_count = doc.chunk_count
|
|
@@ -35,12 +31,7 @@ class DocumentRepository:
|
|
| 35 |
return self.db.query(MedicalDocument).filter(MedicalDocument.id == doc_id).first()
|
| 36 |
|
| 37 |
def list_all(self, limit: int = 100) -> list[MedicalDocument]:
|
| 38 |
-
return (
|
| 39 |
-
self.db.query(MedicalDocument)
|
| 40 |
-
.order_by(MedicalDocument.created_at.desc())
|
| 41 |
-
.limit(limit)
|
| 42 |
-
.all()
|
| 43 |
-
)
|
| 44 |
|
| 45 |
def count(self) -> int:
|
| 46 |
return self.db.query(MedicalDocument).count()
|
|
|
|
| 16 |
self.db = db
|
| 17 |
|
| 18 |
def upsert(self, doc: MedicalDocument) -> MedicalDocument:
|
| 19 |
+
existing = self.db.query(MedicalDocument).filter(MedicalDocument.content_hash == doc.content_hash).first()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
if existing:
|
| 21 |
existing.parse_status = doc.parse_status
|
| 22 |
existing.chunk_count = doc.chunk_count
|
|
|
|
| 31 |
return self.db.query(MedicalDocument).filter(MedicalDocument.id == doc_id).first()
|
| 32 |
|
| 33 |
def list_all(self, limit: int = 100) -> list[MedicalDocument]:
|
| 34 |
+
return self.db.query(MedicalDocument).order_by(MedicalDocument.created_at.desc()).limit(limit).all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
def count(self) -> int:
|
| 37 |
return self.db.query(MedicalDocument).count()
|
src/routers/analyze.py
CHANGED
|
@@ -32,13 +32,7 @@ _executor = ThreadPoolExecutor(max_workers=4)
|
|
| 32 |
|
| 33 |
def _score_disease_heuristic(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 34 |
"""Rule-based disease scoring (NOT ML prediction)."""
|
| 35 |
-
scores = {
|
| 36 |
-
"Diabetes": 0.0,
|
| 37 |
-
"Anemia": 0.0,
|
| 38 |
-
"Heart Disease": 0.0,
|
| 39 |
-
"Thrombocytopenia": 0.0,
|
| 40 |
-
"Thalassemia": 0.0
|
| 41 |
-
}
|
| 42 |
|
| 43 |
# Diabetes indicators
|
| 44 |
glucose = biomarkers.get("Glucose")
|
|
@@ -96,11 +90,7 @@ def _score_disease_heuristic(biomarkers: dict[str, float]) -> dict[str, Any]:
|
|
| 96 |
else:
|
| 97 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 98 |
|
| 99 |
-
return {
|
| 100 |
-
"disease": top_disease,
|
| 101 |
-
"confidence": confidence,
|
| 102 |
-
"probabilities": probabilities
|
| 103 |
-
}
|
| 104 |
|
| 105 |
|
| 106 |
async def _run_guild_analysis(
|
|
@@ -123,16 +113,12 @@ async def _run_guild_analysis(
|
|
| 123 |
try:
|
| 124 |
# Run sync function in thread pool
|
| 125 |
from src.state import PatientInput
|
|
|
|
| 126 |
patient_input = PatientInput(
|
| 127 |
-
biomarkers=biomarkers,
|
| 128 |
-
patient_context=patient_ctx,
|
| 129 |
-
model_prediction=model_prediction
|
| 130 |
)
|
| 131 |
loop = asyncio.get_running_loop()
|
| 132 |
-
result = await loop.run_in_executor(
|
| 133 |
-
_executor,
|
| 134 |
-
lambda: ragbot.run(patient_input)
|
| 135 |
-
)
|
| 136 |
except Exception as exc:
|
| 137 |
logger.exception("Guild analysis failed: %s", exc)
|
| 138 |
raise HTTPException(
|
|
@@ -143,10 +129,10 @@ async def _run_guild_analysis(
|
|
| 143 |
elapsed = (time.time() - t0) * 1000
|
| 144 |
|
| 145 |
# Build response from result
|
| 146 |
-
prediction = result.get(
|
| 147 |
-
analysis = result.get(
|
| 148 |
# Try to extract the conversational_summary if it's there
|
| 149 |
-
conversational_summary = analysis.get(
|
| 150 |
|
| 151 |
return AnalysisResponse(
|
| 152 |
status="success",
|
|
|
|
| 32 |
|
| 33 |
def _score_disease_heuristic(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 34 |
"""Rule-based disease scoring (NOT ML prediction)."""
|
| 35 |
+
scores = {"Diabetes": 0.0, "Anemia": 0.0, "Heart Disease": 0.0, "Thrombocytopenia": 0.0, "Thalassemia": 0.0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# Diabetes indicators
|
| 38 |
glucose = biomarkers.get("Glucose")
|
|
|
|
| 90 |
else:
|
| 91 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 92 |
|
| 93 |
+
return {"disease": top_disease, "confidence": confidence, "probabilities": probabilities}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
async def _run_guild_analysis(
|
|
|
|
| 113 |
try:
|
| 114 |
# Run sync function in thread pool
|
| 115 |
from src.state import PatientInput
|
| 116 |
+
|
| 117 |
patient_input = PatientInput(
|
| 118 |
+
biomarkers=biomarkers, patient_context=patient_ctx, model_prediction=model_prediction
|
|
|
|
|
|
|
| 119 |
)
|
| 120 |
loop = asyncio.get_running_loop()
|
| 121 |
+
result = await loop.run_in_executor(_executor, lambda: ragbot.run(patient_input))
|
|
|
|
|
|
|
|
|
|
| 122 |
except Exception as exc:
|
| 123 |
logger.exception("Guild analysis failed: %s", exc)
|
| 124 |
raise HTTPException(
|
|
|
|
| 129 |
elapsed = (time.time() - t0) * 1000
|
| 130 |
|
| 131 |
# Build response from result
|
| 132 |
+
prediction = result.get("model_prediction")
|
| 133 |
+
analysis = result.get("final_response", {})
|
| 134 |
# Try to extract the conversational_summary if it's there
|
| 135 |
+
conversational_summary = analysis.get("conversational_summary") if isinstance(analysis, dict) else str(analysis)
|
| 136 |
|
| 137 |
return AnalysisResponse(
|
| 138 |
status="success",
|
src/routers/ask.py
CHANGED
|
@@ -71,7 +71,7 @@ async def _stream_rag_response(
|
|
| 71 |
) -> AsyncGenerator[str, None]:
|
| 72 |
"""
|
| 73 |
Generate Server-Sent Events for streaming RAG responses.
|
| 74 |
-
|
| 75 |
Event types:
|
| 76 |
- status: Pipeline stage updates
|
| 77 |
- token: Individual response tokens
|
|
@@ -94,7 +94,7 @@ async def _stream_rag_response(
|
|
| 94 |
query=question,
|
| 95 |
biomarkers=biomarkers,
|
| 96 |
patient_context=patient_context,
|
| 97 |
-
)
|
| 98 |
)
|
| 99 |
|
| 100 |
# Send retrieval metadata
|
|
@@ -110,7 +110,7 @@ async def _stream_rag_response(
|
|
| 110 |
words = answer.split()
|
| 111 |
chunk_size = 3 # Send 3 words at a time
|
| 112 |
for i in range(0, len(words), chunk_size):
|
| 113 |
-
chunk = " ".join(words[i:i + chunk_size])
|
| 114 |
if i + chunk_size < len(words):
|
| 115 |
chunk += " "
|
| 116 |
yield f"event: token\ndata: {json.dumps({'text': chunk})}\n\n"
|
|
@@ -129,21 +129,21 @@ async def _stream_rag_response(
|
|
| 129 |
async def ask_medical_question_stream(body: AskRequest, request: Request):
|
| 130 |
"""
|
| 131 |
Stream a medical Q&A response via Server-Sent Events (SSE).
|
| 132 |
-
|
| 133 |
Events:
|
| 134 |
- `status`: Pipeline stage updates (guardrail, retrieve, grade, generate)
|
| 135 |
- `token`: Individual response tokens for real-time display
|
| 136 |
- `metadata`: Retrieval statistics (documents found, relevance scores)
|
| 137 |
- `done`: Completion signal with timing info
|
| 138 |
- `error`: Error details if something fails
|
| 139 |
-
|
| 140 |
Example client code (JavaScript):
|
| 141 |
```javascript
|
| 142 |
const eventSource = new EventSource('/ask/stream', {
|
| 143 |
method: 'POST',
|
| 144 |
body: JSON.stringify({ question: 'What causes high glucose?' })
|
| 145 |
});
|
| 146 |
-
|
| 147 |
eventSource.addEventListener('token', (e) => {
|
| 148 |
const data = JSON.parse(e.data);
|
| 149 |
document.getElementById('response').innerHTML += data.text;
|
|
@@ -178,10 +178,5 @@ async def submit_feedback(body: FeedbackRequest, request: Request):
|
|
| 178 |
"""Submit user feedback for an analysis or RAG response."""
|
| 179 |
tracer = getattr(request.app.state, "tracer", None)
|
| 180 |
if tracer:
|
| 181 |
-
tracer.score(
|
| 182 |
-
trace_id=body.request_id,
|
| 183 |
-
name="user-feedback",
|
| 184 |
-
value=body.score,
|
| 185 |
-
comment=body.comment
|
| 186 |
-
)
|
| 187 |
return FeedbackResponse(request_id=body.request_id)
|
|
|
|
| 71 |
) -> AsyncGenerator[str, None]:
|
| 72 |
"""
|
| 73 |
Generate Server-Sent Events for streaming RAG responses.
|
| 74 |
+
|
| 75 |
Event types:
|
| 76 |
- status: Pipeline stage updates
|
| 77 |
- token: Individual response tokens
|
|
|
|
| 94 |
query=question,
|
| 95 |
biomarkers=biomarkers,
|
| 96 |
patient_context=patient_context,
|
| 97 |
+
),
|
| 98 |
)
|
| 99 |
|
| 100 |
# Send retrieval metadata
|
|
|
|
| 110 |
words = answer.split()
|
| 111 |
chunk_size = 3 # Send 3 words at a time
|
| 112 |
for i in range(0, len(words), chunk_size):
|
| 113 |
+
chunk = " ".join(words[i : i + chunk_size])
|
| 114 |
if i + chunk_size < len(words):
|
| 115 |
chunk += " "
|
| 116 |
yield f"event: token\ndata: {json.dumps({'text': chunk})}\n\n"
|
|
|
|
| 129 |
async def ask_medical_question_stream(body: AskRequest, request: Request):
|
| 130 |
"""
|
| 131 |
Stream a medical Q&A response via Server-Sent Events (SSE).
|
| 132 |
+
|
| 133 |
Events:
|
| 134 |
- `status`: Pipeline stage updates (guardrail, retrieve, grade, generate)
|
| 135 |
- `token`: Individual response tokens for real-time display
|
| 136 |
- `metadata`: Retrieval statistics (documents found, relevance scores)
|
| 137 |
- `done`: Completion signal with timing info
|
| 138 |
- `error`: Error details if something fails
|
| 139 |
+
|
| 140 |
Example client code (JavaScript):
|
| 141 |
```javascript
|
| 142 |
const eventSource = new EventSource('/ask/stream', {
|
| 143 |
method: 'POST',
|
| 144 |
body: JSON.stringify({ question: 'What causes high glucose?' })
|
| 145 |
});
|
| 146 |
+
|
| 147 |
eventSource.addEventListener('token', (e) => {
|
| 148 |
const data = JSON.parse(e.data);
|
| 149 |
document.getElementById('response').innerHTML += data.text;
|
|
|
|
| 178 |
"""Submit user feedback for an analysis or RAG response."""
|
| 179 |
tracer = getattr(request.app.state, "tracer", None)
|
| 180 |
if tracer:
|
| 181 |
+
tracer.score(trace_id=body.request_id, name="user-feedback", value=body.score, comment=body.comment)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
return FeedbackResponse(request_id=body.request_id)
|
src/routers/health.py
CHANGED
|
@@ -42,6 +42,7 @@ async def readiness_check(request: Request) -> HealthResponse:
|
|
| 42 |
from sqlalchemy import text
|
| 43 |
|
| 44 |
from src.database import _engine
|
|
|
|
| 45 |
engine = _engine()
|
| 46 |
if engine is not None:
|
| 47 |
t0 = time.time()
|
|
@@ -62,7 +63,13 @@ async def readiness_check(request: Request) -> HealthResponse:
|
|
| 62 |
info = os_client.health()
|
| 63 |
latency = (time.time() - t0) * 1000
|
| 64 |
os_status = info.get("status", "unknown")
|
| 65 |
-
services.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
else:
|
| 67 |
services.append(ServiceHealth(name="opensearch", status="unavailable"))
|
| 68 |
except Exception as exc:
|
|
@@ -90,7 +97,9 @@ async def readiness_check(request: Request) -> HealthResponse:
|
|
| 90 |
health_info = ollama.health()
|
| 91 |
latency = (time.time() - t0) * 1000
|
| 92 |
is_healthy = isinstance(health_info, dict) and health_info.get("status") == "ok"
|
| 93 |
-
services.append(
|
|
|
|
|
|
|
| 94 |
else:
|
| 95 |
services.append(ServiceHealth(name="ollama", status="unavailable"))
|
| 96 |
except Exception as exc:
|
|
@@ -110,6 +119,7 @@ async def readiness_check(request: Request) -> HealthResponse:
|
|
| 110 |
# --- FAISS (local retriever) ---
|
| 111 |
try:
|
| 112 |
from src.services.retrieval.factory import make_retriever
|
|
|
|
| 113 |
retriever = make_retriever(backend="faiss")
|
| 114 |
if retriever is not None:
|
| 115 |
doc_count = retriever.doc_count()
|
|
|
|
| 42 |
from sqlalchemy import text
|
| 43 |
|
| 44 |
from src.database import _engine
|
| 45 |
+
|
| 46 |
engine = _engine()
|
| 47 |
if engine is not None:
|
| 48 |
t0 = time.time()
|
|
|
|
| 63 |
info = os_client.health()
|
| 64 |
latency = (time.time() - t0) * 1000
|
| 65 |
os_status = info.get("status", "unknown")
|
| 66 |
+
services.append(
|
| 67 |
+
ServiceHealth(
|
| 68 |
+
name="opensearch",
|
| 69 |
+
status="ok" if os_status in ("green", "yellow") else "degraded",
|
| 70 |
+
latency_ms=round(latency, 1),
|
| 71 |
+
)
|
| 72 |
+
)
|
| 73 |
else:
|
| 74 |
services.append(ServiceHealth(name="opensearch", status="unavailable"))
|
| 75 |
except Exception as exc:
|
|
|
|
| 97 |
health_info = ollama.health()
|
| 98 |
latency = (time.time() - t0) * 1000
|
| 99 |
is_healthy = isinstance(health_info, dict) and health_info.get("status") == "ok"
|
| 100 |
+
services.append(
|
| 101 |
+
ServiceHealth(name="ollama", status="ok" if is_healthy else "degraded", latency_ms=round(latency, 1))
|
| 102 |
+
)
|
| 103 |
else:
|
| 104 |
services.append(ServiceHealth(name="ollama", status="unavailable"))
|
| 105 |
except Exception as exc:
|
|
|
|
| 119 |
# --- FAISS (local retriever) ---
|
| 120 |
try:
|
| 121 |
from src.services.retrieval.factory import make_retriever
|
| 122 |
+
|
| 123 |
retriever = make_retriever(backend="faiss")
|
| 124 |
if retriever is not None:
|
| 125 |
doc_count = retriever.doc_count()
|
src/schemas/schemas.py
CHANGED
|
@@ -29,11 +29,13 @@ class NaturalAnalysisRequest(BaseModel):
|
|
| 29 |
"""Natural language biomarker analysis request."""
|
| 30 |
|
| 31 |
message: str = Field(
|
| 32 |
-
...,
|
|
|
|
|
|
|
| 33 |
description="Natural language message with biomarker values",
|
| 34 |
)
|
| 35 |
patient_context: PatientContext | None = Field(
|
| 36 |
-
default_factory=PatientContext,
|
| 37 |
)
|
| 38 |
|
| 39 |
|
|
@@ -41,10 +43,11 @@ class StructuredAnalysisRequest(BaseModel):
|
|
| 41 |
"""Structured biomarker analysis request."""
|
| 42 |
|
| 43 |
biomarkers: dict[str, float] = Field(
|
| 44 |
-
...,
|
|
|
|
| 45 |
)
|
| 46 |
patient_context: PatientContext | None = Field(
|
| 47 |
-
default_factory=PatientContext,
|
| 48 |
)
|
| 49 |
|
| 50 |
@field_validator("biomarkers")
|
|
@@ -59,14 +62,18 @@ class AskRequest(BaseModel):
|
|
| 59 |
"""Freeβform medical question (agentic RAG pipeline)."""
|
| 60 |
|
| 61 |
question: str = Field(
|
| 62 |
-
...,
|
|
|
|
|
|
|
| 63 |
description="Medical question",
|
| 64 |
)
|
| 65 |
biomarkers: dict[str, float] | None = Field(
|
| 66 |
-
None,
|
|
|
|
| 67 |
)
|
| 68 |
patient_context: str | None = Field(
|
| 69 |
-
None,
|
|
|
|
| 70 |
)
|
| 71 |
|
| 72 |
|
|
@@ -80,6 +87,7 @@ class SearchRequest(BaseModel):
|
|
| 80 |
|
| 81 |
class FeedbackRequest(BaseModel):
|
| 82 |
"""User feedback for RAG responses."""
|
|
|
|
| 83 |
request_id: str = Field(..., description="ID of the request being rated")
|
| 84 |
score: float = Field(..., ge=0, le=1, description="Normalized score 0.0 to 1.0")
|
| 85 |
comment: str | None = Field(None, description="Optional textual feedback")
|
|
|
|
| 29 |
"""Natural language biomarker analysis request."""
|
| 30 |
|
| 31 |
message: str = Field(
|
| 32 |
+
...,
|
| 33 |
+
min_length=5,
|
| 34 |
+
max_length=2000,
|
| 35 |
description="Natural language message with biomarker values",
|
| 36 |
)
|
| 37 |
patient_context: PatientContext | None = Field(
|
| 38 |
+
default_factory=lambda: PatientContext(),
|
| 39 |
)
|
| 40 |
|
| 41 |
|
|
|
|
| 43 |
"""Structured biomarker analysis request."""
|
| 44 |
|
| 45 |
biomarkers: dict[str, float] = Field(
|
| 46 |
+
...,
|
| 47 |
+
description="Dict of biomarker name β measured value",
|
| 48 |
)
|
| 49 |
patient_context: PatientContext | None = Field(
|
| 50 |
+
default_factory=lambda: PatientContext(),
|
| 51 |
)
|
| 52 |
|
| 53 |
@field_validator("biomarkers")
|
|
|
|
| 62 |
"""Freeβform medical question (agentic RAG pipeline)."""
|
| 63 |
|
| 64 |
question: str = Field(
|
| 65 |
+
...,
|
| 66 |
+
min_length=3,
|
| 67 |
+
max_length=4000,
|
| 68 |
description="Medical question",
|
| 69 |
)
|
| 70 |
biomarkers: dict[str, float] | None = Field(
|
| 71 |
+
None,
|
| 72 |
+
description="Optional biomarker context",
|
| 73 |
)
|
| 74 |
patient_context: str | None = Field(
|
| 75 |
+
None,
|
| 76 |
+
description="Freeβtext patient context",
|
| 77 |
)
|
| 78 |
|
| 79 |
|
|
|
|
| 87 |
|
| 88 |
class FeedbackRequest(BaseModel):
|
| 89 |
"""User feedback for RAG responses."""
|
| 90 |
+
|
| 91 |
request_id: str = Field(..., description="ID of the request being rated")
|
| 92 |
score: float = Field(..., ge=0, le=1, description="Normalized score 0.0 to 1.0")
|
| 93 |
comment: str | None = Field(None, description="Optional textual feedback")
|
src/services/agents/context.py
CHANGED
|
@@ -15,10 +15,10 @@ from typing import Any
|
|
| 15 |
class AgenticContext:
|
| 16 |
"""Immutable runtime context for agentic RAG nodes."""
|
| 17 |
|
| 18 |
-
llm: Any
|
| 19 |
-
embedding_service: Any
|
| 20 |
-
opensearch_client: Any
|
| 21 |
-
cache: Any
|
| 22 |
-
tracer: Any
|
| 23 |
-
guild: Any | None = None
|
| 24 |
retriever: Any | None = None # BaseRetriever (FAISS or OpenSearch)
|
|
|
|
| 15 |
class AgenticContext:
|
| 16 |
"""Immutable runtime context for agentic RAG nodes."""
|
| 17 |
|
| 18 |
+
llm: Any # LangChain chat model
|
| 19 |
+
embedding_service: Any # EmbeddingService
|
| 20 |
+
opensearch_client: Any # OpenSearchClient
|
| 21 |
+
cache: Any # RedisCache
|
| 22 |
+
tracer: Any # LangfuseTracer
|
| 23 |
+
guild: Any | None = None # ClinicalInsightGuild (original workflow)
|
| 24 |
retriever: Any | None = None # BaseRetriever (FAISS or OpenSearch)
|
src/services/agents/nodes/retrieve_node.py
CHANGED
|
@@ -69,10 +69,7 @@ def retrieve_node(state: dict, *, context: Any) -> dict:
|
|
| 69 |
documents = [
|
| 70 |
{
|
| 71 |
"content": h.get("_source", {}).get("chunk_text", ""),
|
| 72 |
-
"metadata": {
|
| 73 |
-
k: v for k, v in h.get("_source", {}).items()
|
| 74 |
-
if k != "chunk_text"
|
| 75 |
-
},
|
| 76 |
"score": h.get("_score", 0.0),
|
| 77 |
}
|
| 78 |
for h in raw_hits
|
|
@@ -88,10 +85,7 @@ def retrieve_node(state: dict, *, context: Any) -> dict:
|
|
| 88 |
documents = [
|
| 89 |
{
|
| 90 |
"content": h.get("_source", {}).get("chunk_text", ""),
|
| 91 |
-
"metadata": {
|
| 92 |
-
k: v for k, v in h.get("_source", {}).items()
|
| 93 |
-
if k != "chunk_text"
|
| 94 |
-
},
|
| 95 |
"score": h.get("_score", 0.0),
|
| 96 |
}
|
| 97 |
for h in raw_hits
|
|
|
|
| 69 |
documents = [
|
| 70 |
{
|
| 71 |
"content": h.get("_source", {}).get("chunk_text", ""),
|
| 72 |
+
"metadata": {k: v for k, v in h.get("_source", {}).items() if k != "chunk_text"},
|
|
|
|
|
|
|
|
|
|
| 73 |
"score": h.get("_score", 0.0),
|
| 74 |
}
|
| 75 |
for h in raw_hits
|
|
|
|
| 85 |
documents = [
|
| 86 |
{
|
| 87 |
"content": h.get("_source", {}).get("chunk_text", ""),
|
| 88 |
+
"metadata": {k: v for k, v in h.get("_source", {}).items() if k != "chunk_text"},
|
|
|
|
|
|
|
|
|
|
| 89 |
"score": h.get("_score", 0.0),
|
| 90 |
}
|
| 91 |
for h in raw_hits
|
src/services/agents/state.py
CHANGED
|
@@ -13,7 +13,7 @@ from typing import Annotated, Any
|
|
| 13 |
from typing_extensions import TypedDict
|
| 14 |
|
| 15 |
|
| 16 |
-
class AgenticRAGState(TypedDict):
|
| 17 |
"""State flowing through the agentic RAG graph."""
|
| 18 |
|
| 19 |
# ββ Input ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -22,8 +22,8 @@ class AgenticRAGState(TypedDict):
|
|
| 22 |
patient_context: dict[str, Any] | None
|
| 23 |
|
| 24 |
# ββ Guardrail ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
-
guardrail_score: float
|
| 26 |
-
is_in_scope: bool
|
| 27 |
|
| 28 |
# ββ Retrieval ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
retrieved_documents: list[dict[str, Any]]
|
|
@@ -39,7 +39,7 @@ class AgenticRAGState(TypedDict):
|
|
| 39 |
rewritten_query: str | None
|
| 40 |
|
| 41 |
# ββ Generation / routing βββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
-
routing_decision: str
|
| 43 |
final_answer: str | None
|
| 44 |
analysis_result: dict[str, Any] | None
|
| 45 |
|
|
|
|
| 13 |
from typing_extensions import TypedDict
|
| 14 |
|
| 15 |
|
| 16 |
+
class AgenticRAGState(TypedDict, total=False):
|
| 17 |
"""State flowing through the agentic RAG graph."""
|
| 18 |
|
| 19 |
# ββ Input ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 22 |
patient_context: dict[str, Any] | None
|
| 23 |
|
| 24 |
# ββ Guardrail ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
+
guardrail_score: float # 0-100 medical-relevance score
|
| 26 |
+
is_in_scope: bool # passed guardrail?
|
| 27 |
|
| 28 |
# ββ Retrieval ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
retrieved_documents: list[dict[str, Any]]
|
|
|
|
| 39 |
rewritten_query: str | None
|
| 40 |
|
| 41 |
# ββ Generation / routing βββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
routing_decision: str # "analyze" | "rag_answer" | "out_of_scope"
|
| 43 |
final_answer: str | None
|
| 44 |
analysis_result: dict[str, Any] | None
|
| 45 |
|