Spaces:
Sleeping
Sleeping
Nikhil Pravin Pise commited on
Commit ·
696f787
1
Parent(s): fd5543a
Fix codebase issues: linting, types, tests, and security.
Browse files- Resolved over 3,000 ruff linting violations
- Enforced strict type checking with mypy
- Fixed infinite loop in pytest suite by migrating obsolete tests
- Remediated security warnings flagged by Bandit
This view is limited to 50 files because it contains too many changes.
See raw diff
- Makefile +4 -2
- airflow/dags/ingest_pdfs.py +7 -4
- alembic/env.py +9 -11
- alembic/versions/001_initial.py +81 -0
- api/app/main.py +10 -13
- api/app/routes/analyze.py +31 -37
- api/app/routes/biomarkers.py +16 -17
- api/app/routes/health.py +12 -15
- api/app/services/extraction.py +27 -27
- api/app/services/ragbot.py +69 -62
- archive/evolution/__init__.py +11 -17
- archive/evolution/director.py +55 -55
- archive/evolution/pareto.py +48 -46
- archive/sop_evolution.py +2 -1
- {tests → archive/tests}/test_evolution_loop.py +37 -37
- {tests → archive/tests}/test_evolution_quick.py +14 -13
- docker-compose.yml +20 -0
- gradio_launcher.py +24 -0
- huggingface/app.py +180 -133
- pytest.ini +3 -0
- requirements.txt +0 -41
- scripts/chat.py +73 -73
- scripts/monitor_test.py +1 -1
- scripts/setup_embeddings.py +16 -16
- scripts/test_chat_demo.py +4 -5
- scripts/test_extraction.py +8 -7
- src/agents/biomarker_analyzer.py +23 -23
- src/agents/biomarker_linker.py +41 -41
- src/agents/clinical_guidelines.py +43 -42
- src/agents/confidence_assessor.py +46 -46
- src/agents/disease_explainer.py +36 -34
- src/agents/response_synthesizer.py +52 -51
- src/biomarker_normalization.py +1 -2
- src/biomarker_validator.py +31 -31
- src/config.py +15 -14
- src/database.py +2 -2
- src/dependencies.py +0 -3
- src/evaluation/__init__.py +7 -7
- src/evaluation/evaluators.py +72 -70
- src/exceptions.py +2 -3
- src/gradio_app.py +65 -25
- src/llm_config.py +68 -67
- src/main.py +20 -23
- src/middlewares.py +24 -23
- src/pdf_processor.py +52 -53
- src/repositories/analysis.py +2 -4
- src/repositories/document.py +2 -4
- src/routers/analyze.py +28 -32
- src/routers/ask.py +25 -12
- src/routers/health.py +9 -7
Makefile
CHANGED
|
@@ -117,12 +117,14 @@ index-pdfs: ## Parse and index all medical PDFs
|
|
| 117 |
from pathlib import Path; \
|
| 118 |
from src.services.pdf_parser.service import make_pdf_parser_service; \
|
| 119 |
from src.services.indexing.service import IndexingService; \
|
|
|
|
| 120 |
from src.services.embeddings.service import make_embedding_service; \
|
| 121 |
from src.services.opensearch.client import make_opensearch_client; \
|
| 122 |
parser = make_pdf_parser_service(); \
|
| 123 |
-
|
|
|
|
| 124 |
docs = parser.parse_directory(Path('data/medical_pdfs')); \
|
| 125 |
-
[idx.index_text(d.full_text,
|
| 126 |
print(f'Indexed {len(docs)} documents')"
|
| 127 |
|
| 128 |
# ---------------------------------------------------------------------------
|
|
|
|
| 117 |
from pathlib import Path; \
|
| 118 |
from src.services.pdf_parser.service import make_pdf_parser_service; \
|
| 119 |
from src.services.indexing.service import IndexingService; \
|
| 120 |
+
from src.services.indexing.text_chunker import MedicalTextChunker; \
|
| 121 |
from src.services.embeddings.service import make_embedding_service; \
|
| 122 |
from src.services.opensearch.client import make_opensearch_client; \
|
| 123 |
parser = make_pdf_parser_service(); \
|
| 124 |
+
chunker = MedicalTextChunker(); \
|
| 125 |
+
idx = IndexingService(chunker, make_embedding_service(), make_opensearch_client()); \
|
| 126 |
docs = parser.parse_directory(Path('data/medical_pdfs')); \
|
| 127 |
+
[idx.index_text(d.full_text, title=d.filename, source_file=d.filename) for d in docs if d.full_text]; \
|
| 128 |
print(f'Indexed {len(docs)} documents')"
|
| 129 |
|
| 130 |
# ---------------------------------------------------------------------------
|
airflow/dags/ingest_pdfs.py
CHANGED
|
@@ -9,9 +9,10 @@ from __future__ import annotations
|
|
| 9 |
|
| 10 |
from datetime import datetime, timedelta
|
| 11 |
|
| 12 |
-
from airflow import DAG
|
| 13 |
from airflow.operators.python import PythonOperator
|
| 14 |
|
|
|
|
|
|
|
| 15 |
default_args = {
|
| 16 |
"owner": "mediguard",
|
| 17 |
"retries": 2,
|
|
@@ -26,23 +27,25 @@ def _ingest_pdfs(**kwargs):
|
|
| 26 |
|
| 27 |
from src.services.embeddings.service import make_embedding_service
|
| 28 |
from src.services.indexing.service import IndexingService
|
|
|
|
| 29 |
from src.services.opensearch.client import make_opensearch_client
|
| 30 |
from src.services.pdf_parser.service import make_pdf_parser_service
|
| 31 |
from src.settings import get_settings
|
| 32 |
|
| 33 |
settings = get_settings()
|
| 34 |
-
pdf_dir = Path(settings.
|
| 35 |
|
| 36 |
parser = make_pdf_parser_service()
|
| 37 |
embedding_svc = make_embedding_service()
|
| 38 |
os_client = make_opensearch_client()
|
| 39 |
-
|
|
|
|
| 40 |
|
| 41 |
docs = parser.parse_directory(pdf_dir)
|
| 42 |
indexed = 0
|
| 43 |
for doc in docs:
|
| 44 |
if doc.full_text and not doc.error:
|
| 45 |
-
indexing_svc.index_text(doc.full_text,
|
| 46 |
indexed += 1
|
| 47 |
|
| 48 |
print(f"Ingested {indexed}/{len(docs)} documents")
|
|
|
|
| 9 |
|
| 10 |
from datetime import datetime, timedelta
|
| 11 |
|
|
|
|
| 12 |
from airflow.operators.python import PythonOperator
|
| 13 |
|
| 14 |
+
from airflow import DAG
|
| 15 |
+
|
| 16 |
default_args = {
|
| 17 |
"owner": "mediguard",
|
| 18 |
"retries": 2,
|
|
|
|
| 27 |
|
| 28 |
from src.services.embeddings.service import make_embedding_service
|
| 29 |
from src.services.indexing.service import IndexingService
|
| 30 |
+
from src.services.indexing.text_chunker import MedicalTextChunker
|
| 31 |
from src.services.opensearch.client import make_opensearch_client
|
| 32 |
from src.services.pdf_parser.service import make_pdf_parser_service
|
| 33 |
from src.settings import get_settings
|
| 34 |
|
| 35 |
settings = get_settings()
|
| 36 |
+
pdf_dir = Path(settings.pdf.pdf_directory)
|
| 37 |
|
| 38 |
parser = make_pdf_parser_service()
|
| 39 |
embedding_svc = make_embedding_service()
|
| 40 |
os_client = make_opensearch_client()
|
| 41 |
+
chunker = MedicalTextChunker(target_words=settings.chunking.chunk_size, overlap_words=settings.chunking.chunk_overlap, min_words=settings.chunking.min_chunk_size)
|
| 42 |
+
indexing_svc = IndexingService(chunker, embedding_svc, os_client)
|
| 43 |
|
| 44 |
docs = parser.parse_directory(pdf_dir)
|
| 45 |
indexed = 0
|
| 46 |
for doc in docs:
|
| 47 |
if doc.full_text and not doc.error:
|
| 48 |
+
indexing_svc.index_text(doc.full_text, title=doc.filename, source_file=doc.filename)
|
| 49 |
indexed += 1
|
| 50 |
|
| 51 |
print(f"Ingested {indexed}/{len(docs)} documents")
|
alembic/env.py
CHANGED
|
@@ -1,25 +1,23 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
from sqlalchemy import engine_from_config
|
| 4 |
-
from sqlalchemy import pool, create_engine
|
| 5 |
-
|
| 6 |
-
from alembic import context
|
| 7 |
|
| 8 |
# ---------------------------------------------------------------------------
|
| 9 |
# MediGuard AI — Alembic env.py
|
| 10 |
# Pull DB URL from settings so we never hard-code credentials.
|
| 11 |
# ---------------------------------------------------------------------------
|
| 12 |
import sys
|
| 13 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Make sure the project root is on sys.path
|
| 16 |
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
| 17 |
|
| 18 |
-
from src.settings import get_settings # noqa: E402
|
| 19 |
-
from src.database import Base # noqa: E402
|
| 20 |
-
|
| 21 |
# Import all models so Alembic's autogenerate can see them
|
| 22 |
-
import src.models.analysis # noqa: F401
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# this is the Alembic Config object, which provides
|
| 25 |
# access to the values within the .ini file in use.
|
|
|
|
| 1 |
+
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
# ---------------------------------------------------------------------------
|
| 4 |
# MediGuard AI — Alembic env.py
|
| 5 |
# Pull DB URL from settings so we never hard-code credentials.
|
| 6 |
# ---------------------------------------------------------------------------
|
| 7 |
import sys
|
| 8 |
+
from logging.config import fileConfig
|
| 9 |
+
|
| 10 |
+
from sqlalchemy import engine_from_config, pool
|
| 11 |
+
|
| 12 |
+
from alembic import context
|
| 13 |
|
| 14 |
# Make sure the project root is on sys.path
|
| 15 |
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
# Import all models so Alembic's autogenerate can see them
|
| 18 |
+
import src.models.analysis # noqa: F401
|
| 19 |
+
from src.database import Base
|
| 20 |
+
from src.settings import get_settings
|
| 21 |
|
| 22 |
# this is the Alembic Config object, which provides
|
| 23 |
# access to the values within the .ini file in use.
|
alembic/versions/001_initial.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""initial_tables
|
| 2 |
+
|
| 3 |
+
Revision ID: 001
|
| 4 |
+
Revises:
|
| 5 |
+
Create Date: 2026-02-24 20:58:00.000000
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
import sqlalchemy as sa
|
| 9 |
+
|
| 10 |
+
from alembic import op
|
| 11 |
+
|
| 12 |
+
# revision identifiers, used by Alembic.
|
| 13 |
+
revision = '001'
|
| 14 |
+
down_revision = None
|
| 15 |
+
branch_labels = None
|
| 16 |
+
depends_on = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def upgrade() -> None:
|
| 20 |
+
op.create_table(
|
| 21 |
+
'patient_analyses',
|
| 22 |
+
sa.Column('id', sa.String(length=36), nullable=False),
|
| 23 |
+
sa.Column('request_id', sa.String(length=64), nullable=False),
|
| 24 |
+
sa.Column('biomarkers', sa.JSON(), nullable=False),
|
| 25 |
+
sa.Column('patient_context', sa.JSON(), nullable=True),
|
| 26 |
+
sa.Column('predicted_disease', sa.String(length=128), nullable=False),
|
| 27 |
+
sa.Column('confidence', sa.Float(), nullable=False),
|
| 28 |
+
sa.Column('probabilities', sa.JSON(), nullable=True),
|
| 29 |
+
sa.Column('analysis_result', sa.JSON(), nullable=True),
|
| 30 |
+
sa.Column('safety_alerts', sa.JSON(), nullable=True),
|
| 31 |
+
sa.Column('sop_version', sa.String(length=64), nullable=True),
|
| 32 |
+
sa.Column('processing_time_ms', sa.Float(), nullable=False),
|
| 33 |
+
sa.Column('model_provider', sa.String(length=32), nullable=True),
|
| 34 |
+
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
| 35 |
+
sa.PrimaryKeyConstraint('id')
|
| 36 |
+
)
|
| 37 |
+
op.create_index(op.f('ix_patient_analyses_request_id'), 'patient_analyses', ['request_id'], unique=True)
|
| 38 |
+
|
| 39 |
+
op.create_table(
|
| 40 |
+
'medical_documents',
|
| 41 |
+
sa.Column('id', sa.String(length=36), nullable=False),
|
| 42 |
+
sa.Column('title', sa.String(length=512), nullable=False),
|
| 43 |
+
sa.Column('source', sa.String(length=512), nullable=False),
|
| 44 |
+
sa.Column('source_type', sa.String(length=32), nullable=False),
|
| 45 |
+
sa.Column('authors', sa.Text(), nullable=True),
|
| 46 |
+
sa.Column('abstract', sa.Text(), nullable=True),
|
| 47 |
+
sa.Column('content_hash', sa.String(length=64), nullable=True),
|
| 48 |
+
sa.Column('page_count', sa.Integer(), nullable=True),
|
| 49 |
+
sa.Column('chunk_count', sa.Integer(), nullable=True),
|
| 50 |
+
sa.Column('parse_status', sa.String(length=32), nullable=False),
|
| 51 |
+
sa.Column('metadata_json', sa.JSON(), nullable=True),
|
| 52 |
+
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
| 53 |
+
sa.Column('indexed_at', sa.DateTime(timezone=True), nullable=True),
|
| 54 |
+
sa.PrimaryKeyConstraint('id'),
|
| 55 |
+
sa.UniqueConstraint('content_hash')
|
| 56 |
+
)
|
| 57 |
+
op.create_index(op.f('ix_medical_documents_title'), 'medical_documents', ['title'], unique=False)
|
| 58 |
+
|
| 59 |
+
op.create_table(
|
| 60 |
+
'sop_versions',
|
| 61 |
+
sa.Column('id', sa.String(length=36), nullable=False),
|
| 62 |
+
sa.Column('version_tag', sa.String(length=64), nullable=False),
|
| 63 |
+
sa.Column('parameters', sa.JSON(), nullable=False),
|
| 64 |
+
sa.Column('evaluation_scores', sa.JSON(), nullable=True),
|
| 65 |
+
sa.Column('parent_version', sa.String(length=64), nullable=True),
|
| 66 |
+
sa.Column('is_active', sa.Boolean(), nullable=False),
|
| 67 |
+
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
| 68 |
+
sa.PrimaryKeyConstraint('id')
|
| 69 |
+
)
|
| 70 |
+
op.create_index(op.f('ix_sop_versions_version_tag'), 'sop_versions', ['version_tag'], unique=True)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def downgrade() -> None:
|
| 74 |
+
op.drop_index(op.f('ix_sop_versions_version_tag'), table_name='sop_versions')
|
| 75 |
+
op.drop_table('sop_versions')
|
| 76 |
+
|
| 77 |
+
op.drop_index(op.f('ix_medical_documents_title'), table_name='medical_documents')
|
| 78 |
+
op.drop_table('medical_documents')
|
| 79 |
+
|
| 80 |
+
op.drop_index(op.f('ix_patient_analyses_request_id'), table_name='patient_analyses')
|
| 81 |
+
op.drop_table('patient_analyses')
|
api/app/main.py
CHANGED
|
@@ -3,22 +3,19 @@ RagBot FastAPI Main Application
|
|
| 3 |
Medical biomarker analysis API
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import os
|
| 7 |
-
import sys
|
| 8 |
import logging
|
| 9 |
-
|
| 10 |
from contextlib import asynccontextmanager
|
| 11 |
|
| 12 |
from fastapi import FastAPI, Request, status
|
|
|
|
| 13 |
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
from fastapi.responses import JSONResponse
|
| 15 |
-
from fastapi.exceptions import RequestValidationError
|
| 16 |
|
| 17 |
from app import __version__
|
| 18 |
-
from app.routes import
|
| 19 |
from app.services.ragbot import get_ragbot_service
|
| 20 |
|
| 21 |
-
|
| 22 |
# Configure logging
|
| 23 |
logging.basicConfig(
|
| 24 |
level=logging.INFO,
|
|
@@ -40,7 +37,7 @@ async def lifespan(app: FastAPI):
|
|
| 40 |
logger.info("=" * 70)
|
| 41 |
logger.info("Starting RagBot API Server")
|
| 42 |
logger.info("=" * 70)
|
| 43 |
-
|
| 44 |
# Startup: Initialize RagBot service
|
| 45 |
try:
|
| 46 |
ragbot_service = get_ragbot_service()
|
|
@@ -49,12 +46,12 @@ async def lifespan(app: FastAPI):
|
|
| 49 |
except Exception as e:
|
| 50 |
logger.error(f"Failed to initialize RagBot service: {e}")
|
| 51 |
logger.warning("API will start but health checks will fail")
|
| 52 |
-
|
| 53 |
logger.info("API server ready to accept requests")
|
| 54 |
logger.info("=" * 70)
|
| 55 |
-
|
| 56 |
yield # Server runs here
|
| 57 |
-
|
| 58 |
# Shutdown
|
| 59 |
logger.info("Shutting down RagBot API Server")
|
| 60 |
|
|
@@ -178,14 +175,14 @@ async def api_v1_info():
|
|
| 178 |
|
| 179 |
if __name__ == "__main__":
|
| 180 |
import uvicorn
|
| 181 |
-
|
| 182 |
# Get configuration from environment
|
| 183 |
host = os.getenv("API_HOST", "0.0.0.0")
|
| 184 |
port = int(os.getenv("API_PORT", "8000"))
|
| 185 |
reload = os.getenv("API_RELOAD", "false").lower() == "true"
|
| 186 |
-
|
| 187 |
logger.info(f"Starting server on {host}:{port}")
|
| 188 |
-
|
| 189 |
uvicorn.run(
|
| 190 |
"app.main:app",
|
| 191 |
host=host,
|
|
|
|
| 3 |
Medical biomarker analysis API
|
| 4 |
"""
|
| 5 |
|
|
|
|
|
|
|
| 6 |
import logging
|
| 7 |
+
import os
|
| 8 |
from contextlib import asynccontextmanager
|
| 9 |
|
| 10 |
from fastapi import FastAPI, Request, status
|
| 11 |
+
from fastapi.exceptions import RequestValidationError
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
from fastapi.responses import JSONResponse
|
|
|
|
| 14 |
|
| 15 |
from app import __version__
|
| 16 |
+
from app.routes import analyze, biomarkers, health
|
| 17 |
from app.services.ragbot import get_ragbot_service
|
| 18 |
|
|
|
|
| 19 |
# Configure logging
|
| 20 |
logging.basicConfig(
|
| 21 |
level=logging.INFO,
|
|
|
|
| 37 |
logger.info("=" * 70)
|
| 38 |
logger.info("Starting RagBot API Server")
|
| 39 |
logger.info("=" * 70)
|
| 40 |
+
|
| 41 |
# Startup: Initialize RagBot service
|
| 42 |
try:
|
| 43 |
ragbot_service = get_ragbot_service()
|
|
|
|
| 46 |
except Exception as e:
|
| 47 |
logger.error(f"Failed to initialize RagBot service: {e}")
|
| 48 |
logger.warning("API will start but health checks will fail")
|
| 49 |
+
|
| 50 |
logger.info("API server ready to accept requests")
|
| 51 |
logger.info("=" * 70)
|
| 52 |
+
|
| 53 |
yield # Server runs here
|
| 54 |
+
|
| 55 |
# Shutdown
|
| 56 |
logger.info("Shutting down RagBot API Server")
|
| 57 |
|
|
|
|
| 175 |
|
| 176 |
if __name__ == "__main__":
|
| 177 |
import uvicorn
|
| 178 |
+
|
| 179 |
# Get configuration from environment
|
| 180 |
host = os.getenv("API_HOST", "0.0.0.0")
|
| 181 |
port = int(os.getenv("API_PORT", "8000"))
|
| 182 |
reload = os.getenv("API_RELOAD", "false").lower() == "true"
|
| 183 |
+
|
| 184 |
logger.info(f"Starting server on {host}:{port}")
|
| 185 |
+
|
| 186 |
uvicorn.run(
|
| 187 |
"app.main:app",
|
| 188 |
host=host,
|
api/app/routes/analyze.py
CHANGED
|
@@ -4,19 +4,13 @@ Natural language and structured biomarker analysis
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import os
|
| 7 |
-
|
| 8 |
from fastapi import APIRouter, HTTPException, status
|
| 9 |
|
| 10 |
-
from app.models.schemas import
|
| 11 |
-
NaturalAnalysisRequest,
|
| 12 |
-
StructuredAnalysisRequest,
|
| 13 |
-
AnalysisResponse,
|
| 14 |
-
ErrorResponse
|
| 15 |
-
)
|
| 16 |
from app.services.extraction import extract_biomarkers, predict_disease_simple
|
| 17 |
from app.services.ragbot import get_ragbot_service
|
| 18 |
|
| 19 |
-
|
| 20 |
router = APIRouter(prefix="/api/v1", tags=["analysis"])
|
| 21 |
|
| 22 |
|
|
@@ -45,23 +39,23 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 45 |
|
| 46 |
Returns full detailed analysis with all agent outputs, citations, recommendations.
|
| 47 |
"""
|
| 48 |
-
|
| 49 |
# Get services
|
| 50 |
ragbot_service = get_ragbot_service()
|
| 51 |
-
|
| 52 |
if not ragbot_service.is_ready():
|
| 53 |
raise HTTPException(
|
| 54 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 55 |
detail="RagBot service not initialized. Please try again in a moment."
|
| 56 |
)
|
| 57 |
-
|
| 58 |
# Extract biomarkers from natural language
|
| 59 |
ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
| 60 |
biomarkers, extracted_context, error = extract_biomarkers(
|
| 61 |
request.message,
|
| 62 |
ollama_base_url=ollama_base_url
|
| 63 |
)
|
| 64 |
-
|
| 65 |
if error:
|
| 66 |
raise HTTPException(
|
| 67 |
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -72,7 +66,7 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 72 |
"suggestion": "Try: 'My glucose is 140 and HbA1c is 7.5'"
|
| 73 |
}
|
| 74 |
)
|
| 75 |
-
|
| 76 |
if not biomarkers:
|
| 77 |
raise HTTPException(
|
| 78 |
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -83,14 +77,14 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 83 |
"suggestion": "Include specific biomarker values like 'glucose is 140'"
|
| 84 |
}
|
| 85 |
)
|
| 86 |
-
|
| 87 |
# Merge extracted context with request context
|
| 88 |
patient_context = request.patient_context.model_dump() if request.patient_context else {}
|
| 89 |
patient_context.update(extracted_context)
|
| 90 |
-
|
| 91 |
# Predict disease (simple rule-based for now)
|
| 92 |
model_prediction = predict_disease_simple(biomarkers)
|
| 93 |
-
|
| 94 |
try:
|
| 95 |
# Run full analysis
|
| 96 |
response = ragbot_service.analyze(
|
|
@@ -99,15 +93,15 @@ async def analyze_natural(request: NaturalAnalysisRequest):
|
|
| 99 |
model_prediction=model_prediction,
|
| 100 |
extracted_biomarkers=biomarkers # Keep original extraction
|
| 101 |
)
|
| 102 |
-
|
| 103 |
return response
|
| 104 |
-
|
| 105 |
except Exception as e:
|
| 106 |
raise HTTPException(
|
| 107 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 108 |
detail={
|
| 109 |
"error_code": "ANALYSIS_FAILED",
|
| 110 |
-
"message": f"Analysis workflow failed: {
|
| 111 |
"biomarkers_received": biomarkers
|
| 112 |
}
|
| 113 |
)
|
|
@@ -145,16 +139,16 @@ async def analyze_structured(request: StructuredAnalysisRequest):
|
|
| 145 |
Use this endpoint when you already have structured biomarker data.
|
| 146 |
Returns full detailed analysis with all agent outputs, citations, recommendations.
|
| 147 |
"""
|
| 148 |
-
|
| 149 |
# Get services
|
| 150 |
ragbot_service = get_ragbot_service()
|
| 151 |
-
|
| 152 |
if not ragbot_service.is_ready():
|
| 153 |
raise HTTPException(
|
| 154 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 155 |
detail="RagBot service not initialized. Please try again in a moment."
|
| 156 |
)
|
| 157 |
-
|
| 158 |
# Validate biomarkers
|
| 159 |
if not request.biomarkers:
|
| 160 |
raise HTTPException(
|
|
@@ -165,13 +159,13 @@ async def analyze_structured(request: StructuredAnalysisRequest):
|
|
| 165 |
"suggestion": "Provide at least one biomarker with a numeric value"
|
| 166 |
}
|
| 167 |
)
|
| 168 |
-
|
| 169 |
# Patient context
|
| 170 |
patient_context = request.patient_context.model_dump() if request.patient_context else {}
|
| 171 |
-
|
| 172 |
# Predict disease
|
| 173 |
model_prediction = predict_disease_simple(request.biomarkers)
|
| 174 |
-
|
| 175 |
try:
|
| 176 |
# Run full analysis
|
| 177 |
response = ragbot_service.analyze(
|
|
@@ -180,15 +174,15 @@ async def analyze_structured(request: StructuredAnalysisRequest):
|
|
| 180 |
model_prediction=model_prediction,
|
| 181 |
extracted_biomarkers=None # No extraction for structured input
|
| 182 |
)
|
| 183 |
-
|
| 184 |
return response
|
| 185 |
-
|
| 186 |
except Exception as e:
|
| 187 |
raise HTTPException(
|
| 188 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 189 |
detail={
|
| 190 |
"error_code": "ANALYSIS_FAILED",
|
| 191 |
-
"message": f"Analysis workflow failed: {
|
| 192 |
"biomarkers_received": request.biomarkers
|
| 193 |
}
|
| 194 |
)
|
|
@@ -211,16 +205,16 @@ async def get_example():
|
|
| 211 |
|
| 212 |
Same as CLI chatbot 'example' command.
|
| 213 |
"""
|
| 214 |
-
|
| 215 |
# Get services
|
| 216 |
ragbot_service = get_ragbot_service()
|
| 217 |
-
|
| 218 |
if not ragbot_service.is_ready():
|
| 219 |
raise HTTPException(
|
| 220 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 221 |
detail="RagBot service not initialized. Please try again in a moment."
|
| 222 |
)
|
| 223 |
-
|
| 224 |
# Example biomarkers (Type 2 Diabetes patient)
|
| 225 |
biomarkers = {
|
| 226 |
"Glucose": 185.0,
|
|
@@ -235,14 +229,14 @@ async def get_example():
|
|
| 235 |
"Systolic Blood Pressure": 142.0,
|
| 236 |
"Diastolic Blood Pressure": 88.0
|
| 237 |
}
|
| 238 |
-
|
| 239 |
patient_context = {
|
| 240 |
"age": 52,
|
| 241 |
"gender": "male",
|
| 242 |
"bmi": 31.2,
|
| 243 |
"patient_id": "EXAMPLE-001"
|
| 244 |
}
|
| 245 |
-
|
| 246 |
model_prediction = {
|
| 247 |
"disease": "Diabetes",
|
| 248 |
"confidence": 0.87,
|
|
@@ -254,7 +248,7 @@ async def get_example():
|
|
| 254 |
"Thrombocytopenia": 0.01
|
| 255 |
}
|
| 256 |
}
|
| 257 |
-
|
| 258 |
try:
|
| 259 |
# Run analysis
|
| 260 |
response = ragbot_service.analyze(
|
|
@@ -263,14 +257,14 @@ async def get_example():
|
|
| 263 |
model_prediction=model_prediction,
|
| 264 |
extracted_biomarkers=None
|
| 265 |
)
|
| 266 |
-
|
| 267 |
return response
|
| 268 |
-
|
| 269 |
except Exception as e:
|
| 270 |
raise HTTPException(
|
| 271 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 272 |
detail={
|
| 273 |
"error_code": "EXAMPLE_FAILED",
|
| 274 |
-
"message": f"Example analysis failed: {
|
| 275 |
}
|
| 276 |
)
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import os
|
| 7 |
+
|
| 8 |
from fastapi import APIRouter, HTTPException, status
|
| 9 |
|
| 10 |
+
from app.models.schemas import AnalysisResponse, NaturalAnalysisRequest, StructuredAnalysisRequest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from app.services.extraction import extract_biomarkers, predict_disease_simple
|
| 12 |
from app.services.ragbot import get_ragbot_service
|
| 13 |
|
|
|
|
| 14 |
router = APIRouter(prefix="/api/v1", tags=["analysis"])
|
| 15 |
|
| 16 |
|
|
|
|
| 39 |
|
| 40 |
Returns full detailed analysis with all agent outputs, citations, recommendations.
|
| 41 |
"""
|
| 42 |
+
|
| 43 |
# Get services
|
| 44 |
ragbot_service = get_ragbot_service()
|
| 45 |
+
|
| 46 |
if not ragbot_service.is_ready():
|
| 47 |
raise HTTPException(
|
| 48 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 49 |
detail="RagBot service not initialized. Please try again in a moment."
|
| 50 |
)
|
| 51 |
+
|
| 52 |
# Extract biomarkers from natural language
|
| 53 |
ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
| 54 |
biomarkers, extracted_context, error = extract_biomarkers(
|
| 55 |
request.message,
|
| 56 |
ollama_base_url=ollama_base_url
|
| 57 |
)
|
| 58 |
+
|
| 59 |
if error:
|
| 60 |
raise HTTPException(
|
| 61 |
status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
|
| 66 |
"suggestion": "Try: 'My glucose is 140 and HbA1c is 7.5'"
|
| 67 |
}
|
| 68 |
)
|
| 69 |
+
|
| 70 |
if not biomarkers:
|
| 71 |
raise HTTPException(
|
| 72 |
status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
|
| 77 |
"suggestion": "Include specific biomarker values like 'glucose is 140'"
|
| 78 |
}
|
| 79 |
)
|
| 80 |
+
|
| 81 |
# Merge extracted context with request context
|
| 82 |
patient_context = request.patient_context.model_dump() if request.patient_context else {}
|
| 83 |
patient_context.update(extracted_context)
|
| 84 |
+
|
| 85 |
# Predict disease (simple rule-based for now)
|
| 86 |
model_prediction = predict_disease_simple(biomarkers)
|
| 87 |
+
|
| 88 |
try:
|
| 89 |
# Run full analysis
|
| 90 |
response = ragbot_service.analyze(
|
|
|
|
| 93 |
model_prediction=model_prediction,
|
| 94 |
extracted_biomarkers=biomarkers # Keep original extraction
|
| 95 |
)
|
| 96 |
+
|
| 97 |
return response
|
| 98 |
+
|
| 99 |
except Exception as e:
|
| 100 |
raise HTTPException(
|
| 101 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 102 |
detail={
|
| 103 |
"error_code": "ANALYSIS_FAILED",
|
| 104 |
+
"message": f"Analysis workflow failed: {e!s}",
|
| 105 |
"biomarkers_received": biomarkers
|
| 106 |
}
|
| 107 |
)
|
|
|
|
| 139 |
Use this endpoint when you already have structured biomarker data.
|
| 140 |
Returns full detailed analysis with all agent outputs, citations, recommendations.
|
| 141 |
"""
|
| 142 |
+
|
| 143 |
# Get services
|
| 144 |
ragbot_service = get_ragbot_service()
|
| 145 |
+
|
| 146 |
if not ragbot_service.is_ready():
|
| 147 |
raise HTTPException(
|
| 148 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 149 |
detail="RagBot service not initialized. Please try again in a moment."
|
| 150 |
)
|
| 151 |
+
|
| 152 |
# Validate biomarkers
|
| 153 |
if not request.biomarkers:
|
| 154 |
raise HTTPException(
|
|
|
|
| 159 |
"suggestion": "Provide at least one biomarker with a numeric value"
|
| 160 |
}
|
| 161 |
)
|
| 162 |
+
|
| 163 |
# Patient context
|
| 164 |
patient_context = request.patient_context.model_dump() if request.patient_context else {}
|
| 165 |
+
|
| 166 |
# Predict disease
|
| 167 |
model_prediction = predict_disease_simple(request.biomarkers)
|
| 168 |
+
|
| 169 |
try:
|
| 170 |
# Run full analysis
|
| 171 |
response = ragbot_service.analyze(
|
|
|
|
| 174 |
model_prediction=model_prediction,
|
| 175 |
extracted_biomarkers=None # No extraction for structured input
|
| 176 |
)
|
| 177 |
+
|
| 178 |
return response
|
| 179 |
+
|
| 180 |
except Exception as e:
|
| 181 |
raise HTTPException(
|
| 182 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 183 |
detail={
|
| 184 |
"error_code": "ANALYSIS_FAILED",
|
| 185 |
+
"message": f"Analysis workflow failed: {e!s}",
|
| 186 |
"biomarkers_received": request.biomarkers
|
| 187 |
}
|
| 188 |
)
|
|
|
|
| 205 |
|
| 206 |
Same as CLI chatbot 'example' command.
|
| 207 |
"""
|
| 208 |
+
|
| 209 |
# Get services
|
| 210 |
ragbot_service = get_ragbot_service()
|
| 211 |
+
|
| 212 |
if not ragbot_service.is_ready():
|
| 213 |
raise HTTPException(
|
| 214 |
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 215 |
detail="RagBot service not initialized. Please try again in a moment."
|
| 216 |
)
|
| 217 |
+
|
| 218 |
# Example biomarkers (Type 2 Diabetes patient)
|
| 219 |
biomarkers = {
|
| 220 |
"Glucose": 185.0,
|
|
|
|
| 229 |
"Systolic Blood Pressure": 142.0,
|
| 230 |
"Diastolic Blood Pressure": 88.0
|
| 231 |
}
|
| 232 |
+
|
| 233 |
patient_context = {
|
| 234 |
"age": 52,
|
| 235 |
"gender": "male",
|
| 236 |
"bmi": 31.2,
|
| 237 |
"patient_id": "EXAMPLE-001"
|
| 238 |
}
|
| 239 |
+
|
| 240 |
model_prediction = {
|
| 241 |
"disease": "Diabetes",
|
| 242 |
"confidence": 0.87,
|
|
|
|
| 248 |
"Thrombocytopenia": 0.01
|
| 249 |
}
|
| 250 |
}
|
| 251 |
+
|
| 252 |
try:
|
| 253 |
# Run analysis
|
| 254 |
response = ragbot_service.analyze(
|
|
|
|
| 257 |
model_prediction=model_prediction,
|
| 258 |
extracted_biomarkers=None
|
| 259 |
)
|
| 260 |
+
|
| 261 |
return response
|
| 262 |
+
|
| 263 |
except Exception as e:
|
| 264 |
raise HTTPException(
|
| 265 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 266 |
detail={
|
| 267 |
"error_code": "EXAMPLE_FAILED",
|
| 268 |
+
"message": f"Example analysis failed: {e!s}"
|
| 269 |
}
|
| 270 |
)
|
api/app/routes/biomarkers.py
CHANGED
|
@@ -3,13 +3,12 @@ Biomarkers List Endpoint
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import json
|
| 6 |
-
import sys
|
| 7 |
-
from pathlib import Path
|
| 8 |
from datetime import datetime
|
| 9 |
-
from
|
| 10 |
|
| 11 |
-
from
|
| 12 |
|
|
|
|
| 13 |
|
| 14 |
router = APIRouter(prefix="/api/v1", tags=["biomarkers"])
|
| 15 |
|
|
@@ -30,22 +29,22 @@ async def list_biomarkers():
|
|
| 30 |
- Understanding what biomarkers can be analyzed
|
| 31 |
- Getting reference ranges for display
|
| 32 |
"""
|
| 33 |
-
|
| 34 |
try:
|
| 35 |
# Load biomarker references
|
| 36 |
config_path = Path(__file__).parent.parent.parent.parent / "config" / "biomarker_references.json"
|
| 37 |
-
|
| 38 |
-
with open(config_path
|
| 39 |
config_data = json.load(f)
|
| 40 |
-
|
| 41 |
biomarkers_data = config_data.get("biomarkers", {})
|
| 42 |
-
|
| 43 |
biomarkers_list = []
|
| 44 |
-
|
| 45 |
for name, info in biomarkers_data.items():
|
| 46 |
# Parse reference range
|
| 47 |
normal_range_data = info.get("normal_range", {})
|
| 48 |
-
|
| 49 |
if "male" in normal_range_data or "female" in normal_range_data:
|
| 50 |
# Gender-specific ranges
|
| 51 |
reference_range = BiomarkerReferenceRange(
|
|
@@ -62,7 +61,7 @@ async def list_biomarkers():
|
|
| 62 |
male=None,
|
| 63 |
female=None
|
| 64 |
)
|
| 65 |
-
|
| 66 |
biomarker_info = BiomarkerInfo(
|
| 67 |
name=name,
|
| 68 |
unit=info.get("unit", ""),
|
|
@@ -73,23 +72,23 @@ async def list_biomarkers():
|
|
| 73 |
description=info.get("description", ""),
|
| 74 |
clinical_significance=info.get("clinical_significance", {})
|
| 75 |
)
|
| 76 |
-
|
| 77 |
biomarkers_list.append(biomarker_info)
|
| 78 |
-
|
| 79 |
return BiomarkersListResponse(
|
| 80 |
biomarkers=biomarkers_list,
|
| 81 |
total_count=len(biomarkers_list),
|
| 82 |
timestamp=datetime.now().isoformat()
|
| 83 |
)
|
| 84 |
-
|
| 85 |
except FileNotFoundError:
|
| 86 |
raise HTTPException(
|
| 87 |
status_code=500,
|
| 88 |
detail="Biomarker configuration file not found"
|
| 89 |
)
|
| 90 |
-
|
| 91 |
except Exception as e:
|
| 92 |
raise HTTPException(
|
| 93 |
status_code=500,
|
| 94 |
-
detail=f"Failed to load biomarkers: {
|
| 95 |
)
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import json
|
|
|
|
|
|
|
| 6 |
from datetime import datetime
|
| 7 |
+
from pathlib import Path
|
| 8 |
|
| 9 |
+
from fastapi import APIRouter, HTTPException
|
| 10 |
|
| 11 |
+
from app.models.schemas import BiomarkerInfo, BiomarkerReferenceRange, BiomarkersListResponse
|
| 12 |
|
| 13 |
router = APIRouter(prefix="/api/v1", tags=["biomarkers"])
|
| 14 |
|
|
|
|
| 29 |
- Understanding what biomarkers can be analyzed
|
| 30 |
- Getting reference ranges for display
|
| 31 |
"""
|
| 32 |
+
|
| 33 |
try:
|
| 34 |
# Load biomarker references
|
| 35 |
config_path = Path(__file__).parent.parent.parent.parent / "config" / "biomarker_references.json"
|
| 36 |
+
|
| 37 |
+
with open(config_path) as f:
|
| 38 |
config_data = json.load(f)
|
| 39 |
+
|
| 40 |
biomarkers_data = config_data.get("biomarkers", {})
|
| 41 |
+
|
| 42 |
biomarkers_list = []
|
| 43 |
+
|
| 44 |
for name, info in biomarkers_data.items():
|
| 45 |
# Parse reference range
|
| 46 |
normal_range_data = info.get("normal_range", {})
|
| 47 |
+
|
| 48 |
if "male" in normal_range_data or "female" in normal_range_data:
|
| 49 |
# Gender-specific ranges
|
| 50 |
reference_range = BiomarkerReferenceRange(
|
|
|
|
| 61 |
male=None,
|
| 62 |
female=None
|
| 63 |
)
|
| 64 |
+
|
| 65 |
biomarker_info = BiomarkerInfo(
|
| 66 |
name=name,
|
| 67 |
unit=info.get("unit", ""),
|
|
|
|
| 72 |
description=info.get("description", ""),
|
| 73 |
clinical_significance=info.get("clinical_significance", {})
|
| 74 |
)
|
| 75 |
+
|
| 76 |
biomarkers_list.append(biomarker_info)
|
| 77 |
+
|
| 78 |
return BiomarkersListResponse(
|
| 79 |
biomarkers=biomarkers_list,
|
| 80 |
total_count=len(biomarkers_list),
|
| 81 |
timestamp=datetime.now().isoformat()
|
| 82 |
)
|
| 83 |
+
|
| 84 |
except FileNotFoundError:
|
| 85 |
raise HTTPException(
|
| 86 |
status_code=500,
|
| 87 |
detail="Biomarker configuration file not found"
|
| 88 |
)
|
| 89 |
+
|
| 90 |
except Exception as e:
|
| 91 |
raise HTTPException(
|
| 92 |
status_code=500,
|
| 93 |
+
detail=f"Failed to load biomarkers: {e!s}"
|
| 94 |
)
|
api/app/routes/health.py
CHANGED
|
@@ -2,16 +2,13 @@
|
|
| 2 |
Health Check Endpoint
|
| 3 |
"""
|
| 4 |
|
| 5 |
-
import os
|
| 6 |
-
import sys
|
| 7 |
-
from pathlib import Path
|
| 8 |
from datetime import datetime
|
| 9 |
-
from fastapi import APIRouter, HTTPException
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
from app.models.schemas import HealthResponse
|
| 12 |
from app.services.ragbot import get_ragbot_service
|
| 13 |
-
from app import __version__
|
| 14 |
-
|
| 15 |
|
| 16 |
router = APIRouter(prefix="/api/v1", tags=["health"])
|
| 17 |
|
|
@@ -30,16 +27,16 @@ async def health_check():
|
|
| 30 |
Returns health status with component details.
|
| 31 |
"""
|
| 32 |
ragbot_service = get_ragbot_service()
|
| 33 |
-
|
| 34 |
# Check LLM API connection
|
| 35 |
llm_status = "disconnected"
|
| 36 |
available_models = []
|
| 37 |
-
|
| 38 |
try:
|
| 39 |
-
from src.llm_config import
|
| 40 |
-
|
| 41 |
test_llm = get_chat_model(temperature=0.0)
|
| 42 |
-
|
| 43 |
# Try a simple test
|
| 44 |
response = test_llm.invoke("Say OK")
|
| 45 |
if response:
|
|
@@ -50,13 +47,13 @@ async def health_check():
|
|
| 50 |
available_models = ["gemini-2.0-flash (Google)"]
|
| 51 |
else:
|
| 52 |
available_models = ["llama3.1:8b (Ollama)"]
|
| 53 |
-
|
| 54 |
except Exception as e:
|
| 55 |
llm_status = f"error: {str(e)[:100]}"
|
| 56 |
-
|
| 57 |
# Check vector store
|
| 58 |
vector_store_loaded = ragbot_service.is_ready()
|
| 59 |
-
|
| 60 |
# Determine overall status
|
| 61 |
if llm_status == "connected" and vector_store_loaded:
|
| 62 |
overall_status = "healthy"
|
|
@@ -64,7 +61,7 @@ async def health_check():
|
|
| 64 |
overall_status = "degraded"
|
| 65 |
else:
|
| 66 |
overall_status = "unhealthy"
|
| 67 |
-
|
| 68 |
return HealthResponse(
|
| 69 |
status=overall_status,
|
| 70 |
timestamp=datetime.now().isoformat(),
|
|
|
|
| 2 |
Health Check Endpoint
|
| 3 |
"""
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
from datetime import datetime
|
|
|
|
| 6 |
|
| 7 |
+
from fastapi import APIRouter
|
| 8 |
+
|
| 9 |
+
from app import __version__
|
| 10 |
from app.models.schemas import HealthResponse
|
| 11 |
from app.services.ragbot import get_ragbot_service
|
|
|
|
|
|
|
| 12 |
|
| 13 |
router = APIRouter(prefix="/api/v1", tags=["health"])
|
| 14 |
|
|
|
|
| 27 |
Returns health status with component details.
|
| 28 |
"""
|
| 29 |
ragbot_service = get_ragbot_service()
|
| 30 |
+
|
| 31 |
# Check LLM API connection
|
| 32 |
llm_status = "disconnected"
|
| 33 |
available_models = []
|
| 34 |
+
|
| 35 |
try:
|
| 36 |
+
from src.llm_config import DEFAULT_LLM_PROVIDER, get_chat_model
|
| 37 |
+
|
| 38 |
test_llm = get_chat_model(temperature=0.0)
|
| 39 |
+
|
| 40 |
# Try a simple test
|
| 41 |
response = test_llm.invoke("Say OK")
|
| 42 |
if response:
|
|
|
|
| 47 |
available_models = ["gemini-2.0-flash (Google)"]
|
| 48 |
else:
|
| 49 |
available_models = ["llama3.1:8b (Ollama)"]
|
| 50 |
+
|
| 51 |
except Exception as e:
|
| 52 |
llm_status = f"error: {str(e)[:100]}"
|
| 53 |
+
|
| 54 |
# Check vector store
|
| 55 |
vector_store_loaded = ragbot_service.is_ready()
|
| 56 |
+
|
| 57 |
# Determine overall status
|
| 58 |
if llm_status == "connected" and vector_store_loaded:
|
| 59 |
overall_status = "healthy"
|
|
|
|
| 61 |
overall_status = "degraded"
|
| 62 |
else:
|
| 63 |
overall_status = "unhealthy"
|
| 64 |
+
|
| 65 |
return HealthResponse(
|
| 66 |
status=overall_status,
|
| 67 |
timestamp=datetime.now().isoformat(),
|
api/app/services/extraction.py
CHANGED
|
@@ -6,7 +6,7 @@ Extracts biomarker values from natural language text using LLM
|
|
| 6 |
import json
|
| 7 |
import sys
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import
|
| 10 |
|
| 11 |
# Ensure project root is in path for src imports
|
| 12 |
_project_root = str(Path(__file__).parent.parent.parent.parent)
|
|
@@ -14,10 +14,10 @@ if _project_root not in sys.path:
|
|
| 14 |
sys.path.insert(0, _project_root)
|
| 15 |
|
| 16 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
| 17 |
from src.biomarker_normalization import normalize_biomarker_name
|
| 18 |
from src.llm_config import get_chat_model
|
| 19 |
|
| 20 |
-
|
| 21 |
# ============================================================================
|
| 22 |
# EXTRACTION PROMPT
|
| 23 |
# ============================================================================
|
|
@@ -54,7 +54,7 @@ If you cannot find any biomarkers, return {{"biomarkers": {{}}, "patient_context
|
|
| 54 |
# EXTRACTION HELPERS
|
| 55 |
# ============================================================================
|
| 56 |
|
| 57 |
-
def _parse_llm_json(content: str) ->
|
| 58 |
"""Parse JSON payload from LLM output with fallback recovery."""
|
| 59 |
text = content.strip()
|
| 60 |
|
|
@@ -78,9 +78,9 @@ def _parse_llm_json(content: str) -> Dict[str, Any]:
|
|
| 78 |
# ============================================================================
|
| 79 |
|
| 80 |
def extract_biomarkers(
|
| 81 |
-
user_message: str,
|
| 82 |
ollama_base_url: str = None # Kept for backward compatibility, ignored
|
| 83 |
-
) ->
|
| 84 |
"""
|
| 85 |
Extract biomarker values from natural language using LLM.
|
| 86 |
|
|
@@ -102,18 +102,18 @@ def extract_biomarkers(
|
|
| 102 |
try:
|
| 103 |
# Initialize LLM (uses Groq/Gemini by default - FREE)
|
| 104 |
llm = get_chat_model(temperature=0.0)
|
| 105 |
-
|
| 106 |
prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT)
|
| 107 |
chain = prompt | llm
|
| 108 |
-
|
| 109 |
# Invoke LLM
|
| 110 |
response = chain.invoke({"user_message": user_message})
|
| 111 |
content = response.content.strip()
|
| 112 |
-
|
| 113 |
extracted = _parse_llm_json(content)
|
| 114 |
biomarkers = extracted.get("biomarkers", {})
|
| 115 |
patient_context = extracted.get("patient_context", {})
|
| 116 |
-
|
| 117 |
# Normalize biomarker names and convert to float
|
| 118 |
normalized = {}
|
| 119 |
for key, value in biomarkers.items():
|
|
@@ -123,27 +123,27 @@ def extract_biomarkers(
|
|
| 123 |
except (ValueError, TypeError):
|
| 124 |
# Skip invalid values
|
| 125 |
continue
|
| 126 |
-
|
| 127 |
# Clean up patient context (remove null values)
|
| 128 |
patient_context = {k: v for k, v in patient_context.items() if v is not None}
|
| 129 |
-
|
| 130 |
if not normalized:
|
| 131 |
return {}, patient_context, "No biomarkers found in the input"
|
| 132 |
-
|
| 133 |
return normalized, patient_context, ""
|
| 134 |
-
|
| 135 |
except json.JSONDecodeError as e:
|
| 136 |
-
return {}, {}, f"Failed to parse LLM response as JSON: {
|
| 137 |
-
|
| 138 |
except Exception as e:
|
| 139 |
-
return {}, {}, f"Extraction failed: {
|
| 140 |
|
| 141 |
|
| 142 |
# ============================================================================
|
| 143 |
# SIMPLE DISEASE PREDICTION (Fallback)
|
| 144 |
# ============================================================================
|
| 145 |
|
| 146 |
-
def predict_disease_simple(biomarkers:
|
| 147 |
"""
|
| 148 |
Simple rule-based disease prediction based on key biomarkers.
|
| 149 |
Used as a fallback when no ML model is available.
|
|
@@ -161,15 +161,15 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 161 |
"Thrombocytopenia": 0.0,
|
| 162 |
"Thalassemia": 0.0
|
| 163 |
}
|
| 164 |
-
|
| 165 |
# Helper: check both abbreviated and normalized biomarker names
|
| 166 |
# Returns None when biomarker is not present (avoids false triggers)
|
| 167 |
def _get(name, *alt_names):
|
| 168 |
-
val = biomarkers.get(name
|
| 169 |
if val is not None:
|
| 170 |
return val
|
| 171 |
for alt in alt_names:
|
| 172 |
-
val = biomarkers.get(alt
|
| 173 |
if val is not None:
|
| 174 |
return val
|
| 175 |
return None
|
|
@@ -183,7 +183,7 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 183 |
scores["Diabetes"] += 0.2
|
| 184 |
if hba1c is not None and hba1c >= 6.5:
|
| 185 |
scores["Diabetes"] += 0.5
|
| 186 |
-
|
| 187 |
# Anemia indicators
|
| 188 |
hemoglobin = _get("Hemoglobin")
|
| 189 |
mcv = _get("Mean Corpuscular Volume", "MCV")
|
|
@@ -193,7 +193,7 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 193 |
scores["Anemia"] += 0.2
|
| 194 |
if mcv is not None and mcv < 80:
|
| 195 |
scores["Anemia"] += 0.2
|
| 196 |
-
|
| 197 |
# Heart disease indicators
|
| 198 |
cholesterol = _get("Cholesterol")
|
| 199 |
troponin = _get("Troponin")
|
|
@@ -204,32 +204,32 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 204 |
scores["Heart Disease"] += 0.6
|
| 205 |
if ldl is not None and ldl > 190:
|
| 206 |
scores["Heart Disease"] += 0.2
|
| 207 |
-
|
| 208 |
# Thrombocytopenia indicators
|
| 209 |
platelets = _get("Platelets")
|
| 210 |
if platelets is not None and platelets < 150000:
|
| 211 |
scores["Thrombocytopenia"] += 0.6
|
| 212 |
if platelets is not None and platelets < 50000:
|
| 213 |
scores["Thrombocytopenia"] += 0.3
|
| 214 |
-
|
| 215 |
# Thalassemia indicators (simplified)
|
| 216 |
if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
|
| 217 |
scores["Thalassemia"] += 0.4
|
| 218 |
-
|
| 219 |
# Find top prediction
|
| 220 |
top_disease = max(scores, key=scores.get)
|
| 221 |
confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation
|
| 222 |
|
| 223 |
if confidence == 0.0:
|
| 224 |
top_disease = "Undetermined"
|
| 225 |
-
|
| 226 |
# Normalize probabilities to sum to 1.0
|
| 227 |
total = sum(scores.values())
|
| 228 |
if total > 0:
|
| 229 |
probabilities = {k: v / total for k, v in scores.items()}
|
| 230 |
else:
|
| 231 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 232 |
-
|
| 233 |
return {
|
| 234 |
"disease": top_disease,
|
| 235 |
"confidence": confidence,
|
|
|
|
| 6 |
import json
|
| 7 |
import sys
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
|
| 11 |
# Ensure project root is in path for src imports
|
| 12 |
_project_root = str(Path(__file__).parent.parent.parent.parent)
|
|
|
|
| 14 |
sys.path.insert(0, _project_root)
|
| 15 |
|
| 16 |
from langchain_core.prompts import ChatPromptTemplate
|
| 17 |
+
|
| 18 |
from src.biomarker_normalization import normalize_biomarker_name
|
| 19 |
from src.llm_config import get_chat_model
|
| 20 |
|
|
|
|
| 21 |
# ============================================================================
|
| 22 |
# EXTRACTION PROMPT
|
| 23 |
# ============================================================================
|
|
|
|
| 54 |
# EXTRACTION HELPERS
|
| 55 |
# ============================================================================
|
| 56 |
|
| 57 |
+
def _parse_llm_json(content: str) -> dict[str, Any]:
|
| 58 |
"""Parse JSON payload from LLM output with fallback recovery."""
|
| 59 |
text = content.strip()
|
| 60 |
|
|
|
|
| 78 |
# ============================================================================
|
| 79 |
|
| 80 |
def extract_biomarkers(
|
| 81 |
+
user_message: str,
|
| 82 |
ollama_base_url: str = None # Kept for backward compatibility, ignored
|
| 83 |
+
) -> tuple[dict[str, float], dict[str, Any], str]:
|
| 84 |
"""
|
| 85 |
Extract biomarker values from natural language using LLM.
|
| 86 |
|
|
|
|
| 102 |
try:
|
| 103 |
# Initialize LLM (uses Groq/Gemini by default - FREE)
|
| 104 |
llm = get_chat_model(temperature=0.0)
|
| 105 |
+
|
| 106 |
prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT)
|
| 107 |
chain = prompt | llm
|
| 108 |
+
|
| 109 |
# Invoke LLM
|
| 110 |
response = chain.invoke({"user_message": user_message})
|
| 111 |
content = response.content.strip()
|
| 112 |
+
|
| 113 |
extracted = _parse_llm_json(content)
|
| 114 |
biomarkers = extracted.get("biomarkers", {})
|
| 115 |
patient_context = extracted.get("patient_context", {})
|
| 116 |
+
|
| 117 |
# Normalize biomarker names and convert to float
|
| 118 |
normalized = {}
|
| 119 |
for key, value in biomarkers.items():
|
|
|
|
| 123 |
except (ValueError, TypeError):
|
| 124 |
# Skip invalid values
|
| 125 |
continue
|
| 126 |
+
|
| 127 |
# Clean up patient context (remove null values)
|
| 128 |
patient_context = {k: v for k, v in patient_context.items() if v is not None}
|
| 129 |
+
|
| 130 |
if not normalized:
|
| 131 |
return {}, patient_context, "No biomarkers found in the input"
|
| 132 |
+
|
| 133 |
return normalized, patient_context, ""
|
| 134 |
+
|
| 135 |
except json.JSONDecodeError as e:
|
| 136 |
+
return {}, {}, f"Failed to parse LLM response as JSON: {e!s}"
|
| 137 |
+
|
| 138 |
except Exception as e:
|
| 139 |
+
return {}, {}, f"Extraction failed: {e!s}"
|
| 140 |
|
| 141 |
|
| 142 |
# ============================================================================
|
| 143 |
# SIMPLE DISEASE PREDICTION (Fallback)
|
| 144 |
# ============================================================================
|
| 145 |
|
| 146 |
+
def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 147 |
"""
|
| 148 |
Simple rule-based disease prediction based on key biomarkers.
|
| 149 |
Used as a fallback when no ML model is available.
|
|
|
|
| 161 |
"Thrombocytopenia": 0.0,
|
| 162 |
"Thalassemia": 0.0
|
| 163 |
}
|
| 164 |
+
|
| 165 |
# Helper: check both abbreviated and normalized biomarker names
|
| 166 |
# Returns None when biomarker is not present (avoids false triggers)
|
| 167 |
def _get(name, *alt_names):
|
| 168 |
+
val = biomarkers.get(name)
|
| 169 |
if val is not None:
|
| 170 |
return val
|
| 171 |
for alt in alt_names:
|
| 172 |
+
val = biomarkers.get(alt)
|
| 173 |
if val is not None:
|
| 174 |
return val
|
| 175 |
return None
|
|
|
|
| 183 |
scores["Diabetes"] += 0.2
|
| 184 |
if hba1c is not None and hba1c >= 6.5:
|
| 185 |
scores["Diabetes"] += 0.5
|
| 186 |
+
|
| 187 |
# Anemia indicators
|
| 188 |
hemoglobin = _get("Hemoglobin")
|
| 189 |
mcv = _get("Mean Corpuscular Volume", "MCV")
|
|
|
|
| 193 |
scores["Anemia"] += 0.2
|
| 194 |
if mcv is not None and mcv < 80:
|
| 195 |
scores["Anemia"] += 0.2
|
| 196 |
+
|
| 197 |
# Heart disease indicators
|
| 198 |
cholesterol = _get("Cholesterol")
|
| 199 |
troponin = _get("Troponin")
|
|
|
|
| 204 |
scores["Heart Disease"] += 0.6
|
| 205 |
if ldl is not None and ldl > 190:
|
| 206 |
scores["Heart Disease"] += 0.2
|
| 207 |
+
|
| 208 |
# Thrombocytopenia indicators
|
| 209 |
platelets = _get("Platelets")
|
| 210 |
if platelets is not None and platelets < 150000:
|
| 211 |
scores["Thrombocytopenia"] += 0.6
|
| 212 |
if platelets is not None and platelets < 50000:
|
| 213 |
scores["Thrombocytopenia"] += 0.3
|
| 214 |
+
|
| 215 |
# Thalassemia indicators (simplified)
|
| 216 |
if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
|
| 217 |
scores["Thalassemia"] += 0.4
|
| 218 |
+
|
| 219 |
# Find top prediction
|
| 220 |
top_disease = max(scores, key=scores.get)
|
| 221 |
confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation
|
| 222 |
|
| 223 |
if confidence == 0.0:
|
| 224 |
top_disease = "Undetermined"
|
| 225 |
+
|
| 226 |
# Normalize probabilities to sum to 1.0
|
| 227 |
total = sum(scores.values())
|
| 228 |
if total > 0:
|
| 229 |
probabilities = {k: v / total for k, v in scores.items()}
|
| 230 |
else:
|
| 231 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 232 |
+
|
| 233 |
return {
|
| 234 |
"disease": top_disease,
|
| 235 |
"confidence": confidence,
|
api/app/services/ragbot.py
CHANGED
|
@@ -6,22 +6,29 @@ Wraps the RagBot workflow and formats comprehensive responses
|
|
| 6 |
import sys
|
| 7 |
import time
|
| 8 |
import uuid
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
from typing import Dict, Any
|
| 11 |
from datetime import datetime
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Ensure project root is in path for src imports
|
| 14 |
_project_root = str(Path(__file__).parent.parent.parent.parent)
|
| 15 |
if _project_root not in sys.path:
|
| 16 |
sys.path.insert(0, _project_root)
|
| 17 |
|
| 18 |
-
from src.workflow import create_guild
|
| 19 |
-
from src.state import PatientInput
|
| 20 |
from app.models.schemas import (
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
)
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
class RagBotService:
|
|
@@ -29,65 +36,65 @@ class RagBotService:
|
|
| 29 |
Service class to manage RagBot workflow lifecycle.
|
| 30 |
Initializes once, then handles multiple analysis requests.
|
| 31 |
"""
|
| 32 |
-
|
| 33 |
def __init__(self):
|
| 34 |
"""Initialize the workflow (loads vector store, models, etc.)"""
|
| 35 |
self.guild = None
|
| 36 |
self.initialized = False
|
| 37 |
self.init_time = None
|
| 38 |
-
|
| 39 |
def initialize(self):
|
| 40 |
"""Initialize the Clinical Insight Guild (expensive operation)"""
|
| 41 |
if self.initialized:
|
| 42 |
return
|
| 43 |
-
|
| 44 |
print("INFO: Initializing RagBot workflow...")
|
| 45 |
start_time = time.time()
|
| 46 |
-
|
| 47 |
import os
|
| 48 |
-
|
| 49 |
try:
|
| 50 |
# Set working directory via environment so vector store paths resolve
|
| 51 |
# without a process-global os.chdir() (which is thread-unsafe).
|
| 52 |
ragbot_root = Path(__file__).parent.parent.parent.parent
|
| 53 |
os.environ["RAGBOT_ROOT"] = str(ragbot_root)
|
| 54 |
print(f"INFO: Project root: {ragbot_root}")
|
| 55 |
-
|
| 56 |
# Temporarily chdir only during initialization (single-threaded at startup)
|
| 57 |
original_dir = os.getcwd()
|
| 58 |
os.chdir(ragbot_root)
|
| 59 |
-
|
| 60 |
self.guild = create_guild()
|
| 61 |
self.initialized = True
|
| 62 |
self.init_time = datetime.now()
|
| 63 |
-
|
| 64 |
elapsed = (time.time() - start_time) * 1000
|
| 65 |
print(f"OK: RagBot initialized successfully ({elapsed:.0f}ms)")
|
| 66 |
-
|
| 67 |
except Exception as e:
|
| 68 |
print(f"ERROR: Failed to initialize RagBot: {e}")
|
| 69 |
raise
|
| 70 |
-
|
| 71 |
finally:
|
| 72 |
# Restore original directory
|
| 73 |
os.chdir(original_dir)
|
| 74 |
-
|
| 75 |
def get_uptime_seconds(self) -> float:
|
| 76 |
"""Get API uptime in seconds"""
|
| 77 |
if not self.init_time:
|
| 78 |
return 0.0
|
| 79 |
return (datetime.now() - self.init_time).total_seconds()
|
| 80 |
-
|
| 81 |
def is_ready(self) -> bool:
|
| 82 |
"""Check if service is ready to handle requests"""
|
| 83 |
return self.initialized and self.guild is not None
|
| 84 |
-
|
| 85 |
def analyze(
|
| 86 |
self,
|
| 87 |
-
biomarkers:
|
| 88 |
-
patient_context:
|
| 89 |
-
model_prediction:
|
| 90 |
-
extracted_biomarkers:
|
| 91 |
) -> AnalysisResponse:
|
| 92 |
"""
|
| 93 |
Run complete analysis workflow and format full detailed response.
|
|
@@ -103,10 +110,10 @@ class RagBotService:
|
|
| 103 |
"""
|
| 104 |
if not self.is_ready():
|
| 105 |
raise RuntimeError("RagBot service not initialized. Call initialize() first.")
|
| 106 |
-
|
| 107 |
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 108 |
start_time = time.time()
|
| 109 |
-
|
| 110 |
try:
|
| 111 |
# Create PatientInput
|
| 112 |
patient_input = PatientInput(
|
|
@@ -114,13 +121,13 @@ class RagBotService:
|
|
| 114 |
model_prediction=model_prediction,
|
| 115 |
patient_context=patient_context
|
| 116 |
)
|
| 117 |
-
|
| 118 |
# Run workflow
|
| 119 |
workflow_result = self.guild.run(patient_input)
|
| 120 |
-
|
| 121 |
# Calculate processing time
|
| 122 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 123 |
-
|
| 124 |
# Format response
|
| 125 |
response = self._format_response(
|
| 126 |
request_id=request_id,
|
|
@@ -131,21 +138,21 @@ class RagBotService:
|
|
| 131 |
model_prediction=model_prediction,
|
| 132 |
processing_time_ms=processing_time_ms
|
| 133 |
)
|
| 134 |
-
|
| 135 |
return response
|
| 136 |
-
|
| 137 |
except Exception as e:
|
| 138 |
# Re-raise with context
|
| 139 |
-
raise RuntimeError(f"Analysis failed during workflow execution: {
|
| 140 |
-
|
| 141 |
def _format_response(
|
| 142 |
self,
|
| 143 |
request_id: str,
|
| 144 |
-
workflow_result:
|
| 145 |
-
input_biomarkers:
|
| 146 |
-
extracted_biomarkers:
|
| 147 |
-
patient_context:
|
| 148 |
-
model_prediction:
|
| 149 |
processing_time_ms: float
|
| 150 |
) -> AnalysisResponse:
|
| 151 |
"""
|
|
@@ -159,17 +166,17 @@ class RagBotService:
|
|
| 159 |
- safety_alerts: list of SafetyAlert objects
|
| 160 |
- sop_version, processing_timestamp, etc.
|
| 161 |
"""
|
| 162 |
-
|
| 163 |
# The synthesizer output is nested inside final_response
|
| 164 |
final_response = workflow_result.get("final_response", {}) or {}
|
| 165 |
-
|
| 166 |
# Extract main prediction
|
| 167 |
prediction = Prediction(
|
| 168 |
disease=model_prediction["disease"],
|
| 169 |
confidence=model_prediction["confidence"],
|
| 170 |
probabilities=model_prediction.get("probabilities", {})
|
| 171 |
)
|
| 172 |
-
|
| 173 |
# Biomarker flags: prefer state-level data (BiomarkerFlag objects from validator),
|
| 174 |
# fall back to synthesizer output
|
| 175 |
state_flags = workflow_result.get("biomarker_flags", [])
|
|
@@ -188,7 +195,7 @@ class RagBotService:
|
|
| 188 |
BiomarkerFlag(**flag) if isinstance(flag, dict) else BiomarkerFlag(**flag.model_dump())
|
| 189 |
for flag in biomarker_flags_source
|
| 190 |
]
|
| 191 |
-
|
| 192 |
# Safety alerts: prefer state-level data, fall back to synthesizer
|
| 193 |
state_alerts = workflow_result.get("safety_alerts", [])
|
| 194 |
if state_alerts:
|
|
@@ -206,7 +213,7 @@ class RagBotService:
|
|
| 206 |
SafetyAlert(**alert) if isinstance(alert, dict) else SafetyAlert(**alert.model_dump())
|
| 207 |
for alert in safety_alerts_source
|
| 208 |
]
|
| 209 |
-
|
| 210 |
# Extract key drivers from synthesizer output
|
| 211 |
key_drivers_data = final_response.get("key_drivers", [])
|
| 212 |
if not key_drivers_data:
|
|
@@ -215,7 +222,7 @@ class RagBotService:
|
|
| 215 |
for driver in key_drivers_data:
|
| 216 |
if isinstance(driver, dict):
|
| 217 |
key_drivers.append(KeyDriver(**driver))
|
| 218 |
-
|
| 219 |
# Disease explanation from synthesizer
|
| 220 |
disease_exp_data = final_response.get("disease_explanation", {})
|
| 221 |
if not disease_exp_data:
|
|
@@ -225,7 +232,7 @@ class RagBotService:
|
|
| 225 |
citations=disease_exp_data.get("citations", []),
|
| 226 |
retrieved_chunks=disease_exp_data.get("retrieved_chunks")
|
| 227 |
)
|
| 228 |
-
|
| 229 |
# Recommendations from synthesizer
|
| 230 |
recs_data = final_response.get("recommendations", {})
|
| 231 |
if not recs_data:
|
|
@@ -238,7 +245,7 @@ class RagBotService:
|
|
| 238 |
monitoring=recs_data.get("monitoring", []),
|
| 239 |
follow_up=recs_data.get("follow_up")
|
| 240 |
)
|
| 241 |
-
|
| 242 |
# Confidence assessment from synthesizer
|
| 243 |
conf_data = final_response.get("confidence_assessment", {})
|
| 244 |
if not conf_data:
|
|
@@ -249,12 +256,12 @@ class RagBotService:
|
|
| 249 |
limitations=conf_data.get("limitations", []),
|
| 250 |
reasoning=conf_data.get("reasoning")
|
| 251 |
)
|
| 252 |
-
|
| 253 |
# Alternative diagnoses
|
| 254 |
alternative_diagnoses = final_response.get("alternative_diagnoses")
|
| 255 |
if alternative_diagnoses is None:
|
| 256 |
alternative_diagnoses = final_response.get("analysis", {}).get("alternative_diagnoses")
|
| 257 |
-
|
| 258 |
# Assemble complete analysis
|
| 259 |
analysis = Analysis(
|
| 260 |
biomarker_flags=biomarker_flags,
|
|
@@ -265,7 +272,7 @@ class RagBotService:
|
|
| 265 |
confidence_assessment=confidence_assessment,
|
| 266 |
alternative_diagnoses=alternative_diagnoses
|
| 267 |
)
|
| 268 |
-
|
| 269 |
# Agent outputs from state (these are src.state.AgentOutput objects)
|
| 270 |
agent_outputs_data = workflow_result.get("agent_outputs", [])
|
| 271 |
agent_outputs = []
|
|
@@ -274,7 +281,7 @@ class RagBotService:
|
|
| 274 |
agent_outputs.append(AgentOutput(**agent_out.model_dump()))
|
| 275 |
elif isinstance(agent_out, dict):
|
| 276 |
agent_outputs.append(AgentOutput(**agent_out))
|
| 277 |
-
|
| 278 |
# Workflow metadata
|
| 279 |
workflow_metadata = {
|
| 280 |
"sop_version": workflow_result.get("sop_version"),
|
|
@@ -282,12 +289,12 @@ class RagBotService:
|
|
| 282 |
"agents_executed": len(agent_outputs),
|
| 283 |
"workflow_success": True
|
| 284 |
}
|
| 285 |
-
|
| 286 |
# Conversational summary (if available)
|
| 287 |
conversational_summary = final_response.get("conversational_summary")
|
| 288 |
if not conversational_summary:
|
| 289 |
conversational_summary = final_response.get("patient_summary", {}).get("narrative")
|
| 290 |
-
|
| 291 |
# Generate conversational summary if not present
|
| 292 |
if not conversational_summary:
|
| 293 |
conversational_summary = self._generate_conversational_summary(
|
|
@@ -296,7 +303,7 @@ class RagBotService:
|
|
| 296 |
key_drivers=key_drivers,
|
| 297 |
recommendations=recommendations
|
| 298 |
)
|
| 299 |
-
|
| 300 |
# Assemble final response
|
| 301 |
response = AnalysisResponse(
|
| 302 |
status="success",
|
|
@@ -313,9 +320,9 @@ class RagBotService:
|
|
| 313 |
processing_time_ms=processing_time_ms,
|
| 314 |
sop_version=workflow_result.get("sop_version", "Baseline")
|
| 315 |
)
|
| 316 |
-
|
| 317 |
return response
|
| 318 |
-
|
| 319 |
def _generate_conversational_summary(
|
| 320 |
self,
|
| 321 |
prediction: Prediction,
|
|
@@ -324,37 +331,37 @@ class RagBotService:
|
|
| 324 |
recommendations: Recommendations
|
| 325 |
) -> str:
|
| 326 |
"""Generate a simple conversational summary"""
|
| 327 |
-
|
| 328 |
summary_parts = []
|
| 329 |
summary_parts.append("Hi there!\n")
|
| 330 |
summary_parts.append("Based on your biomarkers, I analyzed your results.\n")
|
| 331 |
-
|
| 332 |
# Prediction
|
| 333 |
summary_parts.append(f"\nPrimary Finding: {prediction.disease}")
|
| 334 |
summary_parts.append(f" Confidence: {prediction.confidence:.0%}\n")
|
| 335 |
-
|
| 336 |
# Safety alerts
|
| 337 |
if safety_alerts:
|
| 338 |
summary_parts.append("\nIMPORTANT SAFETY ALERTS:")
|
| 339 |
for alert in safety_alerts[:3]: # Top 3
|
| 340 |
summary_parts.append(f" - {alert.biomarker}: {alert.message}")
|
| 341 |
summary_parts.append(f" Action: {alert.action}")
|
| 342 |
-
|
| 343 |
# Key drivers
|
| 344 |
if key_drivers:
|
| 345 |
summary_parts.append("\nWhy this prediction?")
|
| 346 |
for driver in key_drivers[:3]: # Top 3
|
| 347 |
summary_parts.append(f" - {driver.biomarker} ({driver.value}): {driver.explanation[:100]}...")
|
| 348 |
-
|
| 349 |
# Recommendations
|
| 350 |
if recommendations.immediate_actions:
|
| 351 |
summary_parts.append("\nWhat You Should Do:")
|
| 352 |
for i, action in enumerate(recommendations.immediate_actions[:3], 1):
|
| 353 |
summary_parts.append(f" {i}. {action}")
|
| 354 |
-
|
| 355 |
summary_parts.append("\nImportant: This is an AI-assisted analysis, NOT medical advice.")
|
| 356 |
summary_parts.append(" Please consult a healthcare professional for proper diagnosis and treatment.")
|
| 357 |
-
|
| 358 |
return "\n".join(summary_parts)
|
| 359 |
|
| 360 |
|
|
|
|
| 6 |
import sys
|
| 7 |
import time
|
| 8 |
import uuid
|
|
|
|
|
|
|
| 9 |
from datetime import datetime
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
|
| 13 |
# Ensure project root is in path for src imports
|
| 14 |
_project_root = str(Path(__file__).parent.parent.parent.parent)
|
| 15 |
if _project_root not in sys.path:
|
| 16 |
sys.path.insert(0, _project_root)
|
| 17 |
|
|
|
|
|
|
|
| 18 |
from app.models.schemas import (
|
| 19 |
+
AgentOutput,
|
| 20 |
+
Analysis,
|
| 21 |
+
AnalysisResponse,
|
| 22 |
+
BiomarkerFlag,
|
| 23 |
+
ConfidenceAssessment,
|
| 24 |
+
DiseaseExplanation,
|
| 25 |
+
KeyDriver,
|
| 26 |
+
Prediction,
|
| 27 |
+
Recommendations,
|
| 28 |
+
SafetyAlert,
|
| 29 |
)
|
| 30 |
+
from src.state import PatientInput
|
| 31 |
+
from src.workflow import create_guild
|
| 32 |
|
| 33 |
|
| 34 |
class RagBotService:
|
|
|
|
| 36 |
Service class to manage RagBot workflow lifecycle.
|
| 37 |
Initializes once, then handles multiple analysis requests.
|
| 38 |
"""
|
| 39 |
+
|
| 40 |
def __init__(self):
|
| 41 |
"""Initialize the workflow (loads vector store, models, etc.)"""
|
| 42 |
self.guild = None
|
| 43 |
self.initialized = False
|
| 44 |
self.init_time = None
|
| 45 |
+
|
| 46 |
def initialize(self):
|
| 47 |
"""Initialize the Clinical Insight Guild (expensive operation)"""
|
| 48 |
if self.initialized:
|
| 49 |
return
|
| 50 |
+
|
| 51 |
print("INFO: Initializing RagBot workflow...")
|
| 52 |
start_time = time.time()
|
| 53 |
+
|
| 54 |
import os
|
| 55 |
+
|
| 56 |
try:
|
| 57 |
# Set working directory via environment so vector store paths resolve
|
| 58 |
# without a process-global os.chdir() (which is thread-unsafe).
|
| 59 |
ragbot_root = Path(__file__).parent.parent.parent.parent
|
| 60 |
os.environ["RAGBOT_ROOT"] = str(ragbot_root)
|
| 61 |
print(f"INFO: Project root: {ragbot_root}")
|
| 62 |
+
|
| 63 |
# Temporarily chdir only during initialization (single-threaded at startup)
|
| 64 |
original_dir = os.getcwd()
|
| 65 |
os.chdir(ragbot_root)
|
| 66 |
+
|
| 67 |
self.guild = create_guild()
|
| 68 |
self.initialized = True
|
| 69 |
self.init_time = datetime.now()
|
| 70 |
+
|
| 71 |
elapsed = (time.time() - start_time) * 1000
|
| 72 |
print(f"OK: RagBot initialized successfully ({elapsed:.0f}ms)")
|
| 73 |
+
|
| 74 |
except Exception as e:
|
| 75 |
print(f"ERROR: Failed to initialize RagBot: {e}")
|
| 76 |
raise
|
| 77 |
+
|
| 78 |
finally:
|
| 79 |
# Restore original directory
|
| 80 |
os.chdir(original_dir)
|
| 81 |
+
|
| 82 |
def get_uptime_seconds(self) -> float:
|
| 83 |
"""Get API uptime in seconds"""
|
| 84 |
if not self.init_time:
|
| 85 |
return 0.0
|
| 86 |
return (datetime.now() - self.init_time).total_seconds()
|
| 87 |
+
|
| 88 |
def is_ready(self) -> bool:
|
| 89 |
"""Check if service is ready to handle requests"""
|
| 90 |
return self.initialized and self.guild is not None
|
| 91 |
+
|
| 92 |
def analyze(
|
| 93 |
self,
|
| 94 |
+
biomarkers: dict[str, float],
|
| 95 |
+
patient_context: dict[str, Any],
|
| 96 |
+
model_prediction: dict[str, Any],
|
| 97 |
+
extracted_biomarkers: dict[str, float] = None
|
| 98 |
) -> AnalysisResponse:
|
| 99 |
"""
|
| 100 |
Run complete analysis workflow and format full detailed response.
|
|
|
|
| 110 |
"""
|
| 111 |
if not self.is_ready():
|
| 112 |
raise RuntimeError("RagBot service not initialized. Call initialize() first.")
|
| 113 |
+
|
| 114 |
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 115 |
start_time = time.time()
|
| 116 |
+
|
| 117 |
try:
|
| 118 |
# Create PatientInput
|
| 119 |
patient_input = PatientInput(
|
|
|
|
| 121 |
model_prediction=model_prediction,
|
| 122 |
patient_context=patient_context
|
| 123 |
)
|
| 124 |
+
|
| 125 |
# Run workflow
|
| 126 |
workflow_result = self.guild.run(patient_input)
|
| 127 |
+
|
| 128 |
# Calculate processing time
|
| 129 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 130 |
+
|
| 131 |
# Format response
|
| 132 |
response = self._format_response(
|
| 133 |
request_id=request_id,
|
|
|
|
| 138 |
model_prediction=model_prediction,
|
| 139 |
processing_time_ms=processing_time_ms
|
| 140 |
)
|
| 141 |
+
|
| 142 |
return response
|
| 143 |
+
|
| 144 |
except Exception as e:
|
| 145 |
# Re-raise with context
|
| 146 |
+
raise RuntimeError(f"Analysis failed during workflow execution: {e!s}") from e
|
| 147 |
+
|
| 148 |
def _format_response(
|
| 149 |
self,
|
| 150 |
request_id: str,
|
| 151 |
+
workflow_result: dict[str, Any],
|
| 152 |
+
input_biomarkers: dict[str, float],
|
| 153 |
+
extracted_biomarkers: dict[str, float],
|
| 154 |
+
patient_context: dict[str, Any],
|
| 155 |
+
model_prediction: dict[str, Any],
|
| 156 |
processing_time_ms: float
|
| 157 |
) -> AnalysisResponse:
|
| 158 |
"""
|
|
|
|
| 166 |
- safety_alerts: list of SafetyAlert objects
|
| 167 |
- sop_version, processing_timestamp, etc.
|
| 168 |
"""
|
| 169 |
+
|
| 170 |
# The synthesizer output is nested inside final_response
|
| 171 |
final_response = workflow_result.get("final_response", {}) or {}
|
| 172 |
+
|
| 173 |
# Extract main prediction
|
| 174 |
prediction = Prediction(
|
| 175 |
disease=model_prediction["disease"],
|
| 176 |
confidence=model_prediction["confidence"],
|
| 177 |
probabilities=model_prediction.get("probabilities", {})
|
| 178 |
)
|
| 179 |
+
|
| 180 |
# Biomarker flags: prefer state-level data (BiomarkerFlag objects from validator),
|
| 181 |
# fall back to synthesizer output
|
| 182 |
state_flags = workflow_result.get("biomarker_flags", [])
|
|
|
|
| 195 |
BiomarkerFlag(**flag) if isinstance(flag, dict) else BiomarkerFlag(**flag.model_dump())
|
| 196 |
for flag in biomarker_flags_source
|
| 197 |
]
|
| 198 |
+
|
| 199 |
# Safety alerts: prefer state-level data, fall back to synthesizer
|
| 200 |
state_alerts = workflow_result.get("safety_alerts", [])
|
| 201 |
if state_alerts:
|
|
|
|
| 213 |
SafetyAlert(**alert) if isinstance(alert, dict) else SafetyAlert(**alert.model_dump())
|
| 214 |
for alert in safety_alerts_source
|
| 215 |
]
|
| 216 |
+
|
| 217 |
# Extract key drivers from synthesizer output
|
| 218 |
key_drivers_data = final_response.get("key_drivers", [])
|
| 219 |
if not key_drivers_data:
|
|
|
|
| 222 |
for driver in key_drivers_data:
|
| 223 |
if isinstance(driver, dict):
|
| 224 |
key_drivers.append(KeyDriver(**driver))
|
| 225 |
+
|
| 226 |
# Disease explanation from synthesizer
|
| 227 |
disease_exp_data = final_response.get("disease_explanation", {})
|
| 228 |
if not disease_exp_data:
|
|
|
|
| 232 |
citations=disease_exp_data.get("citations", []),
|
| 233 |
retrieved_chunks=disease_exp_data.get("retrieved_chunks")
|
| 234 |
)
|
| 235 |
+
|
| 236 |
# Recommendations from synthesizer
|
| 237 |
recs_data = final_response.get("recommendations", {})
|
| 238 |
if not recs_data:
|
|
|
|
| 245 |
monitoring=recs_data.get("monitoring", []),
|
| 246 |
follow_up=recs_data.get("follow_up")
|
| 247 |
)
|
| 248 |
+
|
| 249 |
# Confidence assessment from synthesizer
|
| 250 |
conf_data = final_response.get("confidence_assessment", {})
|
| 251 |
if not conf_data:
|
|
|
|
| 256 |
limitations=conf_data.get("limitations", []),
|
| 257 |
reasoning=conf_data.get("reasoning")
|
| 258 |
)
|
| 259 |
+
|
| 260 |
# Alternative diagnoses
|
| 261 |
alternative_diagnoses = final_response.get("alternative_diagnoses")
|
| 262 |
if alternative_diagnoses is None:
|
| 263 |
alternative_diagnoses = final_response.get("analysis", {}).get("alternative_diagnoses")
|
| 264 |
+
|
| 265 |
# Assemble complete analysis
|
| 266 |
analysis = Analysis(
|
| 267 |
biomarker_flags=biomarker_flags,
|
|
|
|
| 272 |
confidence_assessment=confidence_assessment,
|
| 273 |
alternative_diagnoses=alternative_diagnoses
|
| 274 |
)
|
| 275 |
+
|
| 276 |
# Agent outputs from state (these are src.state.AgentOutput objects)
|
| 277 |
agent_outputs_data = workflow_result.get("agent_outputs", [])
|
| 278 |
agent_outputs = []
|
|
|
|
| 281 |
agent_outputs.append(AgentOutput(**agent_out.model_dump()))
|
| 282 |
elif isinstance(agent_out, dict):
|
| 283 |
agent_outputs.append(AgentOutput(**agent_out))
|
| 284 |
+
|
| 285 |
# Workflow metadata
|
| 286 |
workflow_metadata = {
|
| 287 |
"sop_version": workflow_result.get("sop_version"),
|
|
|
|
| 289 |
"agents_executed": len(agent_outputs),
|
| 290 |
"workflow_success": True
|
| 291 |
}
|
| 292 |
+
|
| 293 |
# Conversational summary (if available)
|
| 294 |
conversational_summary = final_response.get("conversational_summary")
|
| 295 |
if not conversational_summary:
|
| 296 |
conversational_summary = final_response.get("patient_summary", {}).get("narrative")
|
| 297 |
+
|
| 298 |
# Generate conversational summary if not present
|
| 299 |
if not conversational_summary:
|
| 300 |
conversational_summary = self._generate_conversational_summary(
|
|
|
|
| 303 |
key_drivers=key_drivers,
|
| 304 |
recommendations=recommendations
|
| 305 |
)
|
| 306 |
+
|
| 307 |
# Assemble final response
|
| 308 |
response = AnalysisResponse(
|
| 309 |
status="success",
|
|
|
|
| 320 |
processing_time_ms=processing_time_ms,
|
| 321 |
sop_version=workflow_result.get("sop_version", "Baseline")
|
| 322 |
)
|
| 323 |
+
|
| 324 |
return response
|
| 325 |
+
|
| 326 |
def _generate_conversational_summary(
|
| 327 |
self,
|
| 328 |
prediction: Prediction,
|
|
|
|
| 331 |
recommendations: Recommendations
|
| 332 |
) -> str:
|
| 333 |
"""Generate a simple conversational summary"""
|
| 334 |
+
|
| 335 |
summary_parts = []
|
| 336 |
summary_parts.append("Hi there!\n")
|
| 337 |
summary_parts.append("Based on your biomarkers, I analyzed your results.\n")
|
| 338 |
+
|
| 339 |
# Prediction
|
| 340 |
summary_parts.append(f"\nPrimary Finding: {prediction.disease}")
|
| 341 |
summary_parts.append(f" Confidence: {prediction.confidence:.0%}\n")
|
| 342 |
+
|
| 343 |
# Safety alerts
|
| 344 |
if safety_alerts:
|
| 345 |
summary_parts.append("\nIMPORTANT SAFETY ALERTS:")
|
| 346 |
for alert in safety_alerts[:3]: # Top 3
|
| 347 |
summary_parts.append(f" - {alert.biomarker}: {alert.message}")
|
| 348 |
summary_parts.append(f" Action: {alert.action}")
|
| 349 |
+
|
| 350 |
# Key drivers
|
| 351 |
if key_drivers:
|
| 352 |
summary_parts.append("\nWhy this prediction?")
|
| 353 |
for driver in key_drivers[:3]: # Top 3
|
| 354 |
summary_parts.append(f" - {driver.biomarker} ({driver.value}): {driver.explanation[:100]}...")
|
| 355 |
+
|
| 356 |
# Recommendations
|
| 357 |
if recommendations.immediate_actions:
|
| 358 |
summary_parts.append("\nWhat You Should Do:")
|
| 359 |
for i, action in enumerate(recommendations.immediate_actions[:3], 1):
|
| 360 |
summary_parts.append(f" {i}. {action}")
|
| 361 |
+
|
| 362 |
summary_parts.append("\nImportant: This is an AI-assisted analysis, NOT medical advice.")
|
| 363 |
summary_parts.append(" Please consult a healthcare professional for proper diagnosis and treatment.")
|
| 364 |
+
|
| 365 |
return "\n".join(summary_parts)
|
| 366 |
|
| 367 |
|
archive/evolution/__init__.py
CHANGED
|
@@ -4,32 +4,26 @@ Self-improvement system for SOP optimization
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from .director import (
|
| 7 |
-
SOPGenePool,
|
| 8 |
Diagnosis,
|
| 9 |
-
SOPMutation,
|
| 10 |
EvolvedSOPs,
|
|
|
|
|
|
|
| 11 |
performance_diagnostician,
|
|
|
|
| 12 |
sop_architect,
|
| 13 |
-
run_evolution_cycle
|
| 14 |
-
)
|
| 15 |
-
|
| 16 |
-
from .pareto import (
|
| 17 |
-
identify_pareto_front,
|
| 18 |
-
visualize_pareto_frontier,
|
| 19 |
-
print_pareto_summary,
|
| 20 |
-
analyze_improvements
|
| 21 |
)
|
|
|
|
| 22 |
|
| 23 |
__all__ = [
|
| 24 |
-
'SOPGenePool',
|
| 25 |
'Diagnosis',
|
| 26 |
-
'SOPMutation',
|
| 27 |
'EvolvedSOPs',
|
| 28 |
-
'
|
| 29 |
-
'
|
| 30 |
-
'
|
| 31 |
'identify_pareto_front',
|
| 32 |
-
'
|
| 33 |
'print_pareto_summary',
|
| 34 |
-
'
|
|
|
|
|
|
|
| 35 |
]
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from .director import (
|
|
|
|
| 7 |
Diagnosis,
|
|
|
|
| 8 |
EvolvedSOPs,
|
| 9 |
+
SOPGenePool,
|
| 10 |
+
SOPMutation,
|
| 11 |
performance_diagnostician,
|
| 12 |
+
run_evolution_cycle,
|
| 13 |
sop_architect,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
)
|
| 15 |
+
from .pareto import analyze_improvements, identify_pareto_front, print_pareto_summary, visualize_pareto_frontier
|
| 16 |
|
| 17 |
__all__ = [
|
|
|
|
| 18 |
'Diagnosis',
|
|
|
|
| 19 |
'EvolvedSOPs',
|
| 20 |
+
'SOPGenePool',
|
| 21 |
+
'SOPMutation',
|
| 22 |
+
'analyze_improvements',
|
| 23 |
'identify_pareto_front',
|
| 24 |
+
'performance_diagnostician',
|
| 25 |
'print_pareto_summary',
|
| 26 |
+
'run_evolution_cycle',
|
| 27 |
+
'sop_architect',
|
| 28 |
+
'visualize_pareto_frontier'
|
| 29 |
]
|
archive/evolution/director.py
CHANGED
|
@@ -3,27 +3,28 @@ MediGuard AI RAG-Helper - Evolution Engine
|
|
| 3 |
Outer Loop Director for SOP Evolution
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import
|
| 7 |
-
from typing import Any,
|
|
|
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
-
|
| 10 |
from src.config import ExplanationSOP
|
| 11 |
from src.evaluation.evaluators import EvaluationResult
|
| 12 |
|
| 13 |
|
| 14 |
class SOPGenePool:
|
| 15 |
"""Manages version control for evolving SOPs"""
|
| 16 |
-
|
| 17 |
def __init__(self):
|
| 18 |
-
self.pool:
|
| 19 |
-
self.gene_pool:
|
| 20 |
self.version_counter = 0
|
| 21 |
-
|
| 22 |
def add(
|
| 23 |
self,
|
| 24 |
sop: ExplanationSOP,
|
| 25 |
evaluation: EvaluationResult,
|
| 26 |
-
parent_version:
|
| 27 |
description: str = ""
|
| 28 |
):
|
| 29 |
"""Add a new SOP to the gene pool"""
|
|
@@ -38,50 +39,50 @@ class SOPGenePool:
|
|
| 38 |
self.pool.append(entry)
|
| 39 |
self.gene_pool = self.pool # Keep in sync
|
| 40 |
print(f"✓ Added SOP v{self.version_counter} to gene pool: {description}")
|
| 41 |
-
|
| 42 |
-
def get_latest(self) ->
|
| 43 |
"""Get the most recent SOP"""
|
| 44 |
return self.pool[-1] if self.pool else None
|
| 45 |
-
|
| 46 |
-
def get_by_version(self, version: int) ->
|
| 47 |
"""Retrieve specific SOP version"""
|
| 48 |
for entry in self.pool:
|
| 49 |
if entry['version'] == version:
|
| 50 |
return entry
|
| 51 |
return None
|
| 52 |
-
|
| 53 |
-
def get_best_by_metric(self, metric: str) ->
|
| 54 |
"""Get SOP with highest score on specific metric"""
|
| 55 |
if not self.pool:
|
| 56 |
return None
|
| 57 |
-
|
| 58 |
best = max(
|
| 59 |
self.pool,
|
| 60 |
key=lambda x: getattr(x['evaluation'], metric).score
|
| 61 |
)
|
| 62 |
return best
|
| 63 |
-
|
| 64 |
def summary(self):
|
| 65 |
"""Print summary of all SOPs in pool"""
|
| 66 |
print("\n" + "=" * 80)
|
| 67 |
print("SOP GENE POOL SUMMARY")
|
| 68 |
print("=" * 80)
|
| 69 |
-
|
| 70 |
for entry in self.pool:
|
| 71 |
v = entry['version']
|
| 72 |
p = entry['parent']
|
| 73 |
desc = entry['description']
|
| 74 |
e = entry['evaluation']
|
| 75 |
-
|
| 76 |
parent_str = "(Baseline)" if p is None else f"(Child of v{p})"
|
| 77 |
-
|
| 78 |
print(f"\nSOP v{v} {parent_str}: {desc}")
|
| 79 |
print(f" Clinical Accuracy: {e.clinical_accuracy.score:.2f}")
|
| 80 |
print(f" Evidence Grounding: {e.evidence_grounding.score:.2f}")
|
| 81 |
print(f" Actionability: {e.actionability.score:.2f}")
|
| 82 |
print(f" Clarity: {e.clarity.score:.2f}")
|
| 83 |
print(f" Safety & Completeness: {e.safety_completeness.score:.2f}")
|
| 84 |
-
|
| 85 |
print("\n" + "=" * 80)
|
| 86 |
|
| 87 |
|
|
@@ -120,7 +121,7 @@ class SOPMutation(BaseModel):
|
|
| 120 |
|
| 121 |
class EvolvedSOPs(BaseModel):
|
| 122 |
"""Container for mutated SOPs from Architect"""
|
| 123 |
-
mutations:
|
| 124 |
|
| 125 |
|
| 126 |
def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
@@ -131,7 +132,7 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
| 131 |
print("\n" + "=" * 70)
|
| 132 |
print("EXECUTING: Performance Diagnostician")
|
| 133 |
print("=" * 70)
|
| 134 |
-
|
| 135 |
# Find lowest score programmatically (no LLM needed)
|
| 136 |
scores = {
|
| 137 |
'clinical_accuracy': evaluation.clinical_accuracy.score,
|
|
@@ -140,7 +141,7 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
| 140 |
'clarity': evaluation.clarity.score,
|
| 141 |
'safety_completeness': evaluation.safety_completeness.score
|
| 142 |
}
|
| 143 |
-
|
| 144 |
reasonings = {
|
| 145 |
'clinical_accuracy': evaluation.clinical_accuracy.reasoning,
|
| 146 |
'evidence_grounding': evaluation.evidence_grounding.reasoning,
|
|
@@ -148,11 +149,11 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
| 148 |
'clarity': evaluation.clarity.reasoning,
|
| 149 |
'safety_completeness': evaluation.safety_completeness.reasoning
|
| 150 |
}
|
| 151 |
-
|
| 152 |
primary_weakness = min(scores, key=scores.get)
|
| 153 |
weakness_score = scores[primary_weakness]
|
| 154 |
weakness_reasoning = reasonings[primary_weakness]
|
| 155 |
-
|
| 156 |
# Generate detailed root cause analysis
|
| 157 |
root_cause_map = {
|
| 158 |
'clinical_accuracy': f"Clinical accuracy score ({weakness_score:.2f}) indicates potential issues with medical interpretations. {weakness_reasoning[:200]}",
|
|
@@ -161,7 +162,7 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
| 161 |
'clarity': f"Clarity score ({weakness_score:.2f}) suggests readability issues. {weakness_reasoning[:200]}",
|
| 162 |
'safety_completeness': f"Safety score ({weakness_score:.2f}) indicates missing risk discussions. {weakness_reasoning[:200]}"
|
| 163 |
}
|
| 164 |
-
|
| 165 |
recommendation_map = {
|
| 166 |
'clinical_accuracy': "Increase RAG depth to access more authoritative medical sources.",
|
| 167 |
'evidence_grounding': "Enforce strict citation requirements and increase RAG depth.",
|
|
@@ -169,17 +170,17 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
| 169 |
'clarity': "Simplify language and reduce technical jargon for better readability.",
|
| 170 |
'safety_completeness': "Add explicit safety warnings and ensure complete risk coverage."
|
| 171 |
}
|
| 172 |
-
|
| 173 |
diagnosis = Diagnosis(
|
| 174 |
primary_weakness=primary_weakness,
|
| 175 |
root_cause_analysis=root_cause_map[primary_weakness],
|
| 176 |
recommendation=recommendation_map[primary_weakness]
|
| 177 |
)
|
| 178 |
-
|
| 179 |
-
print(
|
| 180 |
print(f" Primary weakness: {diagnosis.primary_weakness} ({weakness_score:.3f})")
|
| 181 |
print(f" Recommendation: {diagnosis.recommendation}")
|
| 182 |
-
|
| 183 |
return diagnosis
|
| 184 |
|
| 185 |
|
|
@@ -195,9 +196,9 @@ def sop_architect(
|
|
| 195 |
print("EXECUTING: SOP Architect")
|
| 196 |
print("=" * 70)
|
| 197 |
print(f"Target weakness: {diagnosis.primary_weakness}")
|
| 198 |
-
|
| 199 |
weakness = diagnosis.primary_weakness
|
| 200 |
-
|
| 201 |
# Generate mutations based on weakness type
|
| 202 |
if weakness == 'clarity':
|
| 203 |
mut1 = SOPMutation(
|
|
@@ -226,7 +227,7 @@ def sop_architect(
|
|
| 226 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 227 |
description="Balanced detail with fewer citations for readability"
|
| 228 |
)
|
| 229 |
-
|
| 230 |
elif weakness == 'evidence_grounding':
|
| 231 |
mut1 = SOPMutation(
|
| 232 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
|
|
@@ -254,7 +255,7 @@ def sop_architect(
|
|
| 254 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 255 |
description="Moderate RAG increase with citation enforcement"
|
| 256 |
)
|
| 257 |
-
|
| 258 |
elif weakness == 'actionability':
|
| 259 |
mut1 = SOPMutation(
|
| 260 |
disease_explainer_k=current_sop.disease_explainer_k,
|
|
@@ -282,7 +283,7 @@ def sop_architect(
|
|
| 282 |
critical_value_alert_mode='strict',
|
| 283 |
description="Comprehensive approach with all agents enabled"
|
| 284 |
)
|
| 285 |
-
|
| 286 |
elif weakness == 'clinical_accuracy':
|
| 287 |
mut1 = SOPMutation(
|
| 288 |
disease_explainer_k=10,
|
|
@@ -310,7 +311,7 @@ def sop_architect(
|
|
| 310 |
critical_value_alert_mode='strict',
|
| 311 |
description="High RAG depth with comprehensive detail"
|
| 312 |
)
|
| 313 |
-
|
| 314 |
else: # safety_completeness
|
| 315 |
mut1 = SOPMutation(
|
| 316 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
|
|
@@ -338,14 +339,14 @@ def sop_architect(
|
|
| 338 |
critical_value_alert_mode='strict',
|
| 339 |
description="Maximum coverage with all safety features"
|
| 340 |
)
|
| 341 |
-
|
| 342 |
evolved = EvolvedSOPs(mutations=[mut1, mut2])
|
| 343 |
-
|
| 344 |
print(f"\n✓ Generated {len(evolved.mutations)} mutations")
|
| 345 |
for i, mut in enumerate(evolved.mutations, 1):
|
| 346 |
print(f" {i}. {mut.description}")
|
| 347 |
print(f" Disease K: {mut.disease_explainer_k}, Detail: {mut.explainer_detail_level}")
|
| 348 |
-
|
| 349 |
return evolved
|
| 350 |
|
| 351 |
|
|
@@ -354,7 +355,7 @@ def run_evolution_cycle(
|
|
| 354 |
patient_input: Any,
|
| 355 |
workflow_graph: Any,
|
| 356 |
evaluation_func: Callable
|
| 357 |
-
) ->
|
| 358 |
"""
|
| 359 |
Executes one complete evolution cycle:
|
| 360 |
1. Diagnose current best SOP
|
|
@@ -367,38 +368,37 @@ def run_evolution_cycle(
|
|
| 367 |
print("\n" + "=" * 80)
|
| 368 |
print("STARTING EVOLUTION CYCLE")
|
| 369 |
print("=" * 80)
|
| 370 |
-
|
| 371 |
# Get current best (for simplicity, use latest)
|
| 372 |
current_best = gene_pool.get_latest()
|
| 373 |
if not current_best:
|
| 374 |
raise ValueError("Gene pool is empty. Add baseline SOP first.")
|
| 375 |
-
|
| 376 |
parent_sop = current_best['sop']
|
| 377 |
parent_eval = current_best['evaluation']
|
| 378 |
parent_version = current_best['version']
|
| 379 |
-
|
| 380 |
print(f"\nImproving upon SOP v{parent_version}")
|
| 381 |
-
|
| 382 |
# Step 1: Diagnose
|
| 383 |
diagnosis = performance_diagnostician(parent_eval)
|
| 384 |
-
|
| 385 |
# Step 2: Generate mutations
|
| 386 |
evolved_sops = sop_architect(diagnosis, parent_sop)
|
| 387 |
-
|
| 388 |
# Step 3: Test each mutation
|
| 389 |
new_entries = []
|
| 390 |
for i, mutant_sop_model in enumerate(evolved_sops.mutations, 1):
|
| 391 |
print(f"\n{'=' * 70}")
|
| 392 |
print(f"TESTING MUTATION {i}/{len(evolved_sops.mutations)}: {mutant_sop_model.description}")
|
| 393 |
print("=" * 70)
|
| 394 |
-
|
| 395 |
# Convert SOPMutation to ExplanationSOP
|
| 396 |
mutant_sop_dict = mutant_sop_model.model_dump()
|
| 397 |
description = mutant_sop_dict.pop('description')
|
| 398 |
mutant_sop = ExplanationSOP(**mutant_sop_dict)
|
| 399 |
-
|
| 400 |
# Run workflow with mutated SOP
|
| 401 |
-
from src.state import PatientInput
|
| 402 |
from datetime import datetime
|
| 403 |
graph_input = {
|
| 404 |
"patient_biomarkers": patient_input.biomarkers,
|
|
@@ -414,17 +414,17 @@ def run_evolution_cycle(
|
|
| 414 |
"processing_timestamp": datetime.now().isoformat(),
|
| 415 |
"sop_version": description
|
| 416 |
}
|
| 417 |
-
|
| 418 |
try:
|
| 419 |
final_state = workflow_graph.invoke(graph_input)
|
| 420 |
-
|
| 421 |
# Evaluate output
|
| 422 |
evaluation = evaluation_func(
|
| 423 |
final_response=final_state['final_response'],
|
| 424 |
agent_outputs=final_state['agent_outputs'],
|
| 425 |
biomarkers=patient_input.biomarkers
|
| 426 |
)
|
| 427 |
-
|
| 428 |
# Add to gene pool
|
| 429 |
gene_pool.add(
|
| 430 |
sop=mutant_sop,
|
|
@@ -432,7 +432,7 @@ def run_evolution_cycle(
|
|
| 432 |
parent_version=parent_version,
|
| 433 |
description=description
|
| 434 |
)
|
| 435 |
-
|
| 436 |
new_entries.append({
|
| 437 |
"sop": mutant_sop,
|
| 438 |
"evaluation": evaluation,
|
|
@@ -441,9 +441,9 @@ def run_evolution_cycle(
|
|
| 441 |
except Exception as e:
|
| 442 |
print(f"❌ Mutation {i} failed: {e}")
|
| 443 |
continue
|
| 444 |
-
|
| 445 |
print("\n" + "=" * 80)
|
| 446 |
print("EVOLUTION CYCLE COMPLETE")
|
| 447 |
print("=" * 80)
|
| 448 |
-
|
| 449 |
return new_entries
|
|
|
|
| 3 |
Outer Loop Director for SOP Evolution
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
from collections.abc import Callable
|
| 7 |
+
from typing import Any, Literal
|
| 8 |
+
|
| 9 |
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
from src.config import ExplanationSOP
|
| 12 |
from src.evaluation.evaluators import EvaluationResult
|
| 13 |
|
| 14 |
|
| 15 |
class SOPGenePool:
|
| 16 |
"""Manages version control for evolving SOPs"""
|
| 17 |
+
|
| 18 |
def __init__(self):
|
| 19 |
+
self.pool: list[dict[str, Any]] = []
|
| 20 |
+
self.gene_pool: list[dict[str, Any]] = [] # Alias for compatibility
|
| 21 |
self.version_counter = 0
|
| 22 |
+
|
| 23 |
def add(
|
| 24 |
self,
|
| 25 |
sop: ExplanationSOP,
|
| 26 |
evaluation: EvaluationResult,
|
| 27 |
+
parent_version: int | None = None,
|
| 28 |
description: str = ""
|
| 29 |
):
|
| 30 |
"""Add a new SOP to the gene pool"""
|
|
|
|
| 39 |
self.pool.append(entry)
|
| 40 |
self.gene_pool = self.pool # Keep in sync
|
| 41 |
print(f"✓ Added SOP v{self.version_counter} to gene pool: {description}")
|
| 42 |
+
|
| 43 |
+
def get_latest(self) -> dict[str, Any] | None:
|
| 44 |
"""Get the most recent SOP"""
|
| 45 |
return self.pool[-1] if self.pool else None
|
| 46 |
+
|
| 47 |
+
def get_by_version(self, version: int) -> dict[str, Any] | None:
|
| 48 |
"""Retrieve specific SOP version"""
|
| 49 |
for entry in self.pool:
|
| 50 |
if entry['version'] == version:
|
| 51 |
return entry
|
| 52 |
return None
|
| 53 |
+
|
| 54 |
+
def get_best_by_metric(self, metric: str) -> dict[str, Any] | None:
|
| 55 |
"""Get SOP with highest score on specific metric"""
|
| 56 |
if not self.pool:
|
| 57 |
return None
|
| 58 |
+
|
| 59 |
best = max(
|
| 60 |
self.pool,
|
| 61 |
key=lambda x: getattr(x['evaluation'], metric).score
|
| 62 |
)
|
| 63 |
return best
|
| 64 |
+
|
| 65 |
def summary(self):
|
| 66 |
"""Print summary of all SOPs in pool"""
|
| 67 |
print("\n" + "=" * 80)
|
| 68 |
print("SOP GENE POOL SUMMARY")
|
| 69 |
print("=" * 80)
|
| 70 |
+
|
| 71 |
for entry in self.pool:
|
| 72 |
v = entry['version']
|
| 73 |
p = entry['parent']
|
| 74 |
desc = entry['description']
|
| 75 |
e = entry['evaluation']
|
| 76 |
+
|
| 77 |
parent_str = "(Baseline)" if p is None else f"(Child of v{p})"
|
| 78 |
+
|
| 79 |
print(f"\nSOP v{v} {parent_str}: {desc}")
|
| 80 |
print(f" Clinical Accuracy: {e.clinical_accuracy.score:.2f}")
|
| 81 |
print(f" Evidence Grounding: {e.evidence_grounding.score:.2f}")
|
| 82 |
print(f" Actionability: {e.actionability.score:.2f}")
|
| 83 |
print(f" Clarity: {e.clarity.score:.2f}")
|
| 84 |
print(f" Safety & Completeness: {e.safety_completeness.score:.2f}")
|
| 85 |
+
|
| 86 |
print("\n" + "=" * 80)
|
| 87 |
|
| 88 |
|
|
|
|
| 121 |
|
| 122 |
class EvolvedSOPs(BaseModel):
|
| 123 |
"""Container for mutated SOPs from Architect"""
|
| 124 |
+
mutations: list[SOPMutation]
|
| 125 |
|
| 126 |
|
| 127 |
def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
|
|
|
|
| 132 |
print("\n" + "=" * 70)
|
| 133 |
print("EXECUTING: Performance Diagnostician")
|
| 134 |
print("=" * 70)
|
| 135 |
+
|
| 136 |
# Find lowest score programmatically (no LLM needed)
|
| 137 |
scores = {
|
| 138 |
'clinical_accuracy': evaluation.clinical_accuracy.score,
|
|
|
|
| 141 |
'clarity': evaluation.clarity.score,
|
| 142 |
'safety_completeness': evaluation.safety_completeness.score
|
| 143 |
}
|
| 144 |
+
|
| 145 |
reasonings = {
|
| 146 |
'clinical_accuracy': evaluation.clinical_accuracy.reasoning,
|
| 147 |
'evidence_grounding': evaluation.evidence_grounding.reasoning,
|
|
|
|
| 149 |
'clarity': evaluation.clarity.reasoning,
|
| 150 |
'safety_completeness': evaluation.safety_completeness.reasoning
|
| 151 |
}
|
| 152 |
+
|
| 153 |
primary_weakness = min(scores, key=scores.get)
|
| 154 |
weakness_score = scores[primary_weakness]
|
| 155 |
weakness_reasoning = reasonings[primary_weakness]
|
| 156 |
+
|
| 157 |
# Generate detailed root cause analysis
|
| 158 |
root_cause_map = {
|
| 159 |
'clinical_accuracy': f"Clinical accuracy score ({weakness_score:.2f}) indicates potential issues with medical interpretations. {weakness_reasoning[:200]}",
|
|
|
|
| 162 |
'clarity': f"Clarity score ({weakness_score:.2f}) suggests readability issues. {weakness_reasoning[:200]}",
|
| 163 |
'safety_completeness': f"Safety score ({weakness_score:.2f}) indicates missing risk discussions. {weakness_reasoning[:200]}"
|
| 164 |
}
|
| 165 |
+
|
| 166 |
recommendation_map = {
|
| 167 |
'clinical_accuracy': "Increase RAG depth to access more authoritative medical sources.",
|
| 168 |
'evidence_grounding': "Enforce strict citation requirements and increase RAG depth.",
|
|
|
|
| 170 |
'clarity': "Simplify language and reduce technical jargon for better readability.",
|
| 171 |
'safety_completeness': "Add explicit safety warnings and ensure complete risk coverage."
|
| 172 |
}
|
| 173 |
+
|
| 174 |
diagnosis = Diagnosis(
|
| 175 |
primary_weakness=primary_weakness,
|
| 176 |
root_cause_analysis=root_cause_map[primary_weakness],
|
| 177 |
recommendation=recommendation_map[primary_weakness]
|
| 178 |
)
|
| 179 |
+
|
| 180 |
+
print("\n✓ Diagnosis complete")
|
| 181 |
print(f" Primary weakness: {diagnosis.primary_weakness} ({weakness_score:.3f})")
|
| 182 |
print(f" Recommendation: {diagnosis.recommendation}")
|
| 183 |
+
|
| 184 |
return diagnosis
|
| 185 |
|
| 186 |
|
|
|
|
| 196 |
print("EXECUTING: SOP Architect")
|
| 197 |
print("=" * 70)
|
| 198 |
print(f"Target weakness: {diagnosis.primary_weakness}")
|
| 199 |
+
|
| 200 |
weakness = diagnosis.primary_weakness
|
| 201 |
+
|
| 202 |
# Generate mutations based on weakness type
|
| 203 |
if weakness == 'clarity':
|
| 204 |
mut1 = SOPMutation(
|
|
|
|
| 227 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 228 |
description="Balanced detail with fewer citations for readability"
|
| 229 |
)
|
| 230 |
+
|
| 231 |
elif weakness == 'evidence_grounding':
|
| 232 |
mut1 = SOPMutation(
|
| 233 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
|
|
|
|
| 255 |
critical_value_alert_mode=current_sop.critical_value_alert_mode,
|
| 256 |
description="Moderate RAG increase with citation enforcement"
|
| 257 |
)
|
| 258 |
+
|
| 259 |
elif weakness == 'actionability':
|
| 260 |
mut1 = SOPMutation(
|
| 261 |
disease_explainer_k=current_sop.disease_explainer_k,
|
|
|
|
| 283 |
critical_value_alert_mode='strict',
|
| 284 |
description="Comprehensive approach with all agents enabled"
|
| 285 |
)
|
| 286 |
+
|
| 287 |
elif weakness == 'clinical_accuracy':
|
| 288 |
mut1 = SOPMutation(
|
| 289 |
disease_explainer_k=10,
|
|
|
|
| 311 |
critical_value_alert_mode='strict',
|
| 312 |
description="High RAG depth with comprehensive detail"
|
| 313 |
)
|
| 314 |
+
|
| 315 |
else: # safety_completeness
|
| 316 |
mut1 = SOPMutation(
|
| 317 |
disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
|
|
|
|
| 339 |
critical_value_alert_mode='strict',
|
| 340 |
description="Maximum coverage with all safety features"
|
| 341 |
)
|
| 342 |
+
|
| 343 |
evolved = EvolvedSOPs(mutations=[mut1, mut2])
|
| 344 |
+
|
| 345 |
print(f"\n✓ Generated {len(evolved.mutations)} mutations")
|
| 346 |
for i, mut in enumerate(evolved.mutations, 1):
|
| 347 |
print(f" {i}. {mut.description}")
|
| 348 |
print(f" Disease K: {mut.disease_explainer_k}, Detail: {mut.explainer_detail_level}")
|
| 349 |
+
|
| 350 |
return evolved
|
| 351 |
|
| 352 |
|
|
|
|
| 355 |
patient_input: Any,
|
| 356 |
workflow_graph: Any,
|
| 357 |
evaluation_func: Callable
|
| 358 |
+
) -> list[dict[str, Any]]:
|
| 359 |
"""
|
| 360 |
Executes one complete evolution cycle:
|
| 361 |
1. Diagnose current best SOP
|
|
|
|
| 368 |
print("\n" + "=" * 80)
|
| 369 |
print("STARTING EVOLUTION CYCLE")
|
| 370 |
print("=" * 80)
|
| 371 |
+
|
| 372 |
# Get current best (for simplicity, use latest)
|
| 373 |
current_best = gene_pool.get_latest()
|
| 374 |
if not current_best:
|
| 375 |
raise ValueError("Gene pool is empty. Add baseline SOP first.")
|
| 376 |
+
|
| 377 |
parent_sop = current_best['sop']
|
| 378 |
parent_eval = current_best['evaluation']
|
| 379 |
parent_version = current_best['version']
|
| 380 |
+
|
| 381 |
print(f"\nImproving upon SOP v{parent_version}")
|
| 382 |
+
|
| 383 |
# Step 1: Diagnose
|
| 384 |
diagnosis = performance_diagnostician(parent_eval)
|
| 385 |
+
|
| 386 |
# Step 2: Generate mutations
|
| 387 |
evolved_sops = sop_architect(diagnosis, parent_sop)
|
| 388 |
+
|
| 389 |
# Step 3: Test each mutation
|
| 390 |
new_entries = []
|
| 391 |
for i, mutant_sop_model in enumerate(evolved_sops.mutations, 1):
|
| 392 |
print(f"\n{'=' * 70}")
|
| 393 |
print(f"TESTING MUTATION {i}/{len(evolved_sops.mutations)}: {mutant_sop_model.description}")
|
| 394 |
print("=" * 70)
|
| 395 |
+
|
| 396 |
# Convert SOPMutation to ExplanationSOP
|
| 397 |
mutant_sop_dict = mutant_sop_model.model_dump()
|
| 398 |
description = mutant_sop_dict.pop('description')
|
| 399 |
mutant_sop = ExplanationSOP(**mutant_sop_dict)
|
| 400 |
+
|
| 401 |
# Run workflow with mutated SOP
|
|
|
|
| 402 |
from datetime import datetime
|
| 403 |
graph_input = {
|
| 404 |
"patient_biomarkers": patient_input.biomarkers,
|
|
|
|
| 414 |
"processing_timestamp": datetime.now().isoformat(),
|
| 415 |
"sop_version": description
|
| 416 |
}
|
| 417 |
+
|
| 418 |
try:
|
| 419 |
final_state = workflow_graph.invoke(graph_input)
|
| 420 |
+
|
| 421 |
# Evaluate output
|
| 422 |
evaluation = evaluation_func(
|
| 423 |
final_response=final_state['final_response'],
|
| 424 |
agent_outputs=final_state['agent_outputs'],
|
| 425 |
biomarkers=patient_input.biomarkers
|
| 426 |
)
|
| 427 |
+
|
| 428 |
# Add to gene pool
|
| 429 |
gene_pool.add(
|
| 430 |
sop=mutant_sop,
|
|
|
|
| 432 |
parent_version=parent_version,
|
| 433 |
description=description
|
| 434 |
)
|
| 435 |
+
|
| 436 |
new_entries.append({
|
| 437 |
"sop": mutant_sop,
|
| 438 |
"evaluation": evaluation,
|
|
|
|
| 441 |
except Exception as e:
|
| 442 |
print(f"❌ Mutation {i} failed: {e}")
|
| 443 |
continue
|
| 444 |
+
|
| 445 |
print("\n" + "=" * 80)
|
| 446 |
print("EVOLUTION CYCLE COMPLETE")
|
| 447 |
print("=" * 80)
|
| 448 |
+
|
| 449 |
return new_entries
|
archive/evolution/pareto.py
CHANGED
|
@@ -3,14 +3,16 @@ Pareto Frontier Analysis
|
|
| 3 |
Identifies optimal trade-offs in multi-objective optimization
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
import matplotlib
|
|
|
|
|
|
|
| 9 |
matplotlib.use('Agg') # Use non-interactive backend
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
|
| 12 |
|
| 13 |
-
def identify_pareto_front(gene_pool_entries:
|
| 14 |
"""
|
| 15 |
Identifies non-dominated solutions (Pareto Frontier).
|
| 16 |
|
|
@@ -19,32 +21,32 @@ def identify_pareto_front(gene_pool_entries: List[Dict[str, Any]]) -> List[Dict[
|
|
| 19 |
- Strictly better on AT LEAST ONE metric
|
| 20 |
"""
|
| 21 |
pareto_front = []
|
| 22 |
-
|
| 23 |
for i, candidate in enumerate(gene_pool_entries):
|
| 24 |
is_dominated = False
|
| 25 |
-
|
| 26 |
# Get candidate's 5D score vector
|
| 27 |
cand_scores = np.array(candidate['evaluation'].to_vector())
|
| 28 |
-
|
| 29 |
for j, other in enumerate(gene_pool_entries):
|
| 30 |
if i == j:
|
| 31 |
continue
|
| 32 |
-
|
| 33 |
# Get other solution's 5D vector
|
| 34 |
other_scores = np.array(other['evaluation'].to_vector())
|
| 35 |
-
|
| 36 |
# Check domination: other >= candidate on ALL, other > candidate on SOME
|
| 37 |
if np.all(other_scores >= cand_scores) and np.any(other_scores > cand_scores):
|
| 38 |
is_dominated = True
|
| 39 |
break
|
| 40 |
-
|
| 41 |
if not is_dominated:
|
| 42 |
pareto_front.append(candidate)
|
| 43 |
-
|
| 44 |
return pareto_front
|
| 45 |
|
| 46 |
|
| 47 |
-
def visualize_pareto_frontier(pareto_front:
|
| 48 |
"""
|
| 49 |
Creates two visualizations:
|
| 50 |
1. Parallel coordinates plot (5D)
|
|
@@ -53,16 +55,16 @@ def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
|
|
| 53 |
if not pareto_front:
|
| 54 |
print("No solutions on Pareto front to visualize")
|
| 55 |
return
|
| 56 |
-
|
| 57 |
fig = plt.figure(figsize=(18, 7))
|
| 58 |
-
|
| 59 |
# --- Plot 1: Bar Chart (since pandas might not be available) ---
|
| 60 |
ax1 = plt.subplot(1, 2, 1)
|
| 61 |
-
|
| 62 |
metrics = ['Clinical\nAccuracy', 'Evidence\nGrounding', 'Actionability', 'Clarity', 'Safety']
|
| 63 |
x = np.arange(len(metrics))
|
| 64 |
width = 0.8 / len(pareto_front)
|
| 65 |
-
|
| 66 |
for idx, entry in enumerate(pareto_front):
|
| 67 |
e = entry['evaluation']
|
| 68 |
scores = [
|
|
@@ -72,11 +74,11 @@ def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
|
|
| 72 |
e.clarity.score,
|
| 73 |
e.safety_completeness.score
|
| 74 |
]
|
| 75 |
-
|
| 76 |
offset = (idx - len(pareto_front) / 2) * width + width / 2
|
| 77 |
label = f"SOP v{entry['version']}"
|
| 78 |
ax1.bar(x + offset, scores, width, label=label, alpha=0.8)
|
| 79 |
-
|
| 80 |
ax1.set_xlabel('Metrics', fontsize=12)
|
| 81 |
ax1.set_ylabel('Score', fontsize=12)
|
| 82 |
ax1.set_title('5D Performance Comparison (Bar Chart)', fontsize=14)
|
|
@@ -85,17 +87,17 @@ def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
|
|
| 85 |
ax1.set_ylim(0, 1.0)
|
| 86 |
ax1.legend(loc='upper left')
|
| 87 |
ax1.grid(True, alpha=0.3, axis='y')
|
| 88 |
-
|
| 89 |
# --- Plot 2: Radar Chart ---
|
| 90 |
ax2 = plt.subplot(1, 2, 2, projection='polar')
|
| 91 |
-
|
| 92 |
-
categories = ['Clinical\nAccuracy', 'Evidence\nGrounding',
|
| 93 |
'Actionability', 'Clarity', 'Safety']
|
| 94 |
num_vars = len(categories)
|
| 95 |
-
|
| 96 |
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
| 97 |
angles += angles[:1]
|
| 98 |
-
|
| 99 |
for entry in pareto_front:
|
| 100 |
e = entry['evaluation']
|
| 101 |
values = [
|
|
@@ -106,47 +108,47 @@ def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
|
|
| 106 |
e.safety_completeness.score
|
| 107 |
]
|
| 108 |
values += values[:1]
|
| 109 |
-
|
| 110 |
desc = entry.get('description', '')[:30]
|
| 111 |
label = f"SOP v{entry['version']}: {desc}"
|
| 112 |
ax2.plot(angles, values, 'o-', linewidth=2, label=label)
|
| 113 |
ax2.fill(angles, values, alpha=0.15)
|
| 114 |
-
|
| 115 |
ax2.set_xticks(angles[:-1])
|
| 116 |
ax2.set_xticklabels(categories, size=10)
|
| 117 |
ax2.set_ylim(0, 1)
|
| 118 |
ax2.set_title('5D Performance Profiles (Radar Chart)', size=14, y=1.08)
|
| 119 |
ax2.legend(loc='upper left', bbox_to_anchor=(1.2, 1.0), fontsize=9)
|
| 120 |
ax2.grid(True)
|
| 121 |
-
|
| 122 |
plt.tight_layout()
|
| 123 |
-
|
| 124 |
# Create data directory if it doesn't exist
|
| 125 |
from pathlib import Path
|
| 126 |
data_dir = Path('data')
|
| 127 |
data_dir.mkdir(exist_ok=True)
|
| 128 |
-
|
| 129 |
output_path = data_dir / 'pareto_frontier_analysis.png'
|
| 130 |
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
| 131 |
plt.close()
|
| 132 |
-
|
| 133 |
print(f"\n✓ Visualization saved to: {output_path}")
|
| 134 |
|
| 135 |
|
| 136 |
-
def print_pareto_summary(pareto_front:
|
| 137 |
"""Print human-readable summary of Pareto frontier"""
|
| 138 |
print("\n" + "=" * 80)
|
| 139 |
print("PARETO FRONTIER ANALYSIS")
|
| 140 |
print("=" * 80)
|
| 141 |
-
|
| 142 |
print(f"\nFound {len(pareto_front)} optimal (non-dominated) solutions:\n")
|
| 143 |
-
|
| 144 |
for entry in pareto_front:
|
| 145 |
v = entry['version']
|
| 146 |
p = entry.get('parent')
|
| 147 |
desc = entry.get('description', 'Baseline')
|
| 148 |
e = entry['evaluation']
|
| 149 |
-
|
| 150 |
print(f"SOP v{v} {f'(Child of v{p})' if p else '(Baseline)'}")
|
| 151 |
print(f" Description: {desc}")
|
| 152 |
print(f" Clinical Accuracy: {e.clinical_accuracy.score:.3f}")
|
|
@@ -154,12 +156,12 @@ def print_pareto_summary(pareto_front: List[Dict[str, Any]]):
|
|
| 154 |
print(f" Actionability: {e.actionability.score:.3f}")
|
| 155 |
print(f" Clarity: {e.clarity.score:.3f}")
|
| 156 |
print(f" Safety & Completeness: {e.safety_completeness.score:.3f}")
|
| 157 |
-
|
| 158 |
# Calculate average
|
| 159 |
avg_score = np.mean(e.to_vector())
|
| 160 |
print(f" Average Score: {avg_score:.3f}")
|
| 161 |
print()
|
| 162 |
-
|
| 163 |
print("=" * 80)
|
| 164 |
print("\nRECOMMENDATION:")
|
| 165 |
print("Review the visualizations and choose the SOP that best matches")
|
|
@@ -167,46 +169,46 @@ def print_pareto_summary(pareto_front: List[Dict[str, Any]]):
|
|
| 167 |
print("=" * 80)
|
| 168 |
|
| 169 |
|
| 170 |
-
def analyze_improvements(gene_pool_entries:
|
| 171 |
"""Analyze improvements over baseline"""
|
| 172 |
if len(gene_pool_entries) < 2:
|
| 173 |
print("\n⚠️ Not enough SOPs to analyze improvements")
|
| 174 |
return
|
| 175 |
-
|
| 176 |
baseline = gene_pool_entries[0]
|
| 177 |
baseline_scores = np.array(baseline['evaluation'].to_vector())
|
| 178 |
-
|
| 179 |
print("\n" + "=" * 80)
|
| 180 |
print("IMPROVEMENT ANALYSIS")
|
| 181 |
print("=" * 80)
|
| 182 |
-
|
| 183 |
print(f"\nBaseline (v{baseline['version']}): {baseline.get('description', 'Initial')}")
|
| 184 |
print(f" Average Score: {np.mean(baseline_scores):.3f}")
|
| 185 |
-
|
| 186 |
improvements_found = False
|
| 187 |
for entry in gene_pool_entries[1:]:
|
| 188 |
scores = np.array(entry['evaluation'].to_vector())
|
| 189 |
avg_score = np.mean(scores)
|
| 190 |
baseline_avg = np.mean(baseline_scores)
|
| 191 |
-
|
| 192 |
if avg_score > baseline_avg:
|
| 193 |
improvements_found = True
|
| 194 |
improvement_pct = ((avg_score - baseline_avg) / baseline_avg) * 100
|
| 195 |
-
|
| 196 |
-
print(f"\n✓ SOP v{entry['version']}: {entry.get('description', '')}")
|
| 197 |
print(f" Average Score: {avg_score:.3f} (+{improvement_pct:.1f}% vs baseline)")
|
| 198 |
-
|
| 199 |
# Show per-metric improvements
|
| 200 |
-
metric_names = ['Clinical Accuracy', 'Evidence Grounding', 'Actionability',
|
| 201 |
'Clarity', 'Safety & Completeness']
|
| 202 |
for i, (name, score, baseline_score) in enumerate(zip(metric_names, scores, baseline_scores)):
|
| 203 |
diff = score - baseline_score
|
| 204 |
if abs(diff) > 0.01: # Show significant changes
|
| 205 |
symbol = "↑" if diff > 0 else "↓"
|
| 206 |
print(f" {name}: {score:.3f} {symbol} ({diff:+.3f})")
|
| 207 |
-
|
| 208 |
if not improvements_found:
|
| 209 |
print("\n⚠️ No improvements found over baseline yet")
|
| 210 |
print(" Consider running more evolution cycles or adjusting mutation strategies")
|
| 211 |
-
|
| 212 |
print("\n" + "=" * 80)
|
|
|
|
| 3 |
Identifies optimal trade-offs in multi-objective optimization
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
import matplotlib
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
matplotlib.use('Agg') # Use non-interactive backend
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
|
| 14 |
|
| 15 |
+
def identify_pareto_front(gene_pool_entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
| 16 |
"""
|
| 17 |
Identifies non-dominated solutions (Pareto Frontier).
|
| 18 |
|
|
|
|
| 21 |
- Strictly better on AT LEAST ONE metric
|
| 22 |
"""
|
| 23 |
pareto_front = []
|
| 24 |
+
|
| 25 |
for i, candidate in enumerate(gene_pool_entries):
|
| 26 |
is_dominated = False
|
| 27 |
+
|
| 28 |
# Get candidate's 5D score vector
|
| 29 |
cand_scores = np.array(candidate['evaluation'].to_vector())
|
| 30 |
+
|
| 31 |
for j, other in enumerate(gene_pool_entries):
|
| 32 |
if i == j:
|
| 33 |
continue
|
| 34 |
+
|
| 35 |
# Get other solution's 5D vector
|
| 36 |
other_scores = np.array(other['evaluation'].to_vector())
|
| 37 |
+
|
| 38 |
# Check domination: other >= candidate on ALL, other > candidate on SOME
|
| 39 |
if np.all(other_scores >= cand_scores) and np.any(other_scores > cand_scores):
|
| 40 |
is_dominated = True
|
| 41 |
break
|
| 42 |
+
|
| 43 |
if not is_dominated:
|
| 44 |
pareto_front.append(candidate)
|
| 45 |
+
|
| 46 |
return pareto_front
|
| 47 |
|
| 48 |
|
| 49 |
+
def visualize_pareto_frontier(pareto_front: list[dict[str, Any]]):
|
| 50 |
"""
|
| 51 |
Creates two visualizations:
|
| 52 |
1. Parallel coordinates plot (5D)
|
|
|
|
| 55 |
if not pareto_front:
|
| 56 |
print("No solutions on Pareto front to visualize")
|
| 57 |
return
|
| 58 |
+
|
| 59 |
fig = plt.figure(figsize=(18, 7))
|
| 60 |
+
|
| 61 |
# --- Plot 1: Bar Chart (since pandas might not be available) ---
|
| 62 |
ax1 = plt.subplot(1, 2, 1)
|
| 63 |
+
|
| 64 |
metrics = ['Clinical\nAccuracy', 'Evidence\nGrounding', 'Actionability', 'Clarity', 'Safety']
|
| 65 |
x = np.arange(len(metrics))
|
| 66 |
width = 0.8 / len(pareto_front)
|
| 67 |
+
|
| 68 |
for idx, entry in enumerate(pareto_front):
|
| 69 |
e = entry['evaluation']
|
| 70 |
scores = [
|
|
|
|
| 74 |
e.clarity.score,
|
| 75 |
e.safety_completeness.score
|
| 76 |
]
|
| 77 |
+
|
| 78 |
offset = (idx - len(pareto_front) / 2) * width + width / 2
|
| 79 |
label = f"SOP v{entry['version']}"
|
| 80 |
ax1.bar(x + offset, scores, width, label=label, alpha=0.8)
|
| 81 |
+
|
| 82 |
ax1.set_xlabel('Metrics', fontsize=12)
|
| 83 |
ax1.set_ylabel('Score', fontsize=12)
|
| 84 |
ax1.set_title('5D Performance Comparison (Bar Chart)', fontsize=14)
|
|
|
|
| 87 |
ax1.set_ylim(0, 1.0)
|
| 88 |
ax1.legend(loc='upper left')
|
| 89 |
ax1.grid(True, alpha=0.3, axis='y')
|
| 90 |
+
|
| 91 |
# --- Plot 2: Radar Chart ---
|
| 92 |
ax2 = plt.subplot(1, 2, 2, projection='polar')
|
| 93 |
+
|
| 94 |
+
categories = ['Clinical\nAccuracy', 'Evidence\nGrounding',
|
| 95 |
'Actionability', 'Clarity', 'Safety']
|
| 96 |
num_vars = len(categories)
|
| 97 |
+
|
| 98 |
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
| 99 |
angles += angles[:1]
|
| 100 |
+
|
| 101 |
for entry in pareto_front:
|
| 102 |
e = entry['evaluation']
|
| 103 |
values = [
|
|
|
|
| 108 |
e.safety_completeness.score
|
| 109 |
]
|
| 110 |
values += values[:1]
|
| 111 |
+
|
| 112 |
desc = entry.get('description', '')[:30]
|
| 113 |
label = f"SOP v{entry['version']}: {desc}"
|
| 114 |
ax2.plot(angles, values, 'o-', linewidth=2, label=label)
|
| 115 |
ax2.fill(angles, values, alpha=0.15)
|
| 116 |
+
|
| 117 |
ax2.set_xticks(angles[:-1])
|
| 118 |
ax2.set_xticklabels(categories, size=10)
|
| 119 |
ax2.set_ylim(0, 1)
|
| 120 |
ax2.set_title('5D Performance Profiles (Radar Chart)', size=14, y=1.08)
|
| 121 |
ax2.legend(loc='upper left', bbox_to_anchor=(1.2, 1.0), fontsize=9)
|
| 122 |
ax2.grid(True)
|
| 123 |
+
|
| 124 |
plt.tight_layout()
|
| 125 |
+
|
| 126 |
# Create data directory if it doesn't exist
|
| 127 |
from pathlib import Path
|
| 128 |
data_dir = Path('data')
|
| 129 |
data_dir.mkdir(exist_ok=True)
|
| 130 |
+
|
| 131 |
output_path = data_dir / 'pareto_frontier_analysis.png'
|
| 132 |
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
| 133 |
plt.close()
|
| 134 |
+
|
| 135 |
print(f"\n✓ Visualization saved to: {output_path}")
|
| 136 |
|
| 137 |
|
| 138 |
+
def print_pareto_summary(pareto_front: list[dict[str, Any]]):
|
| 139 |
"""Print human-readable summary of Pareto frontier"""
|
| 140 |
print("\n" + "=" * 80)
|
| 141 |
print("PARETO FRONTIER ANALYSIS")
|
| 142 |
print("=" * 80)
|
| 143 |
+
|
| 144 |
print(f"\nFound {len(pareto_front)} optimal (non-dominated) solutions:\n")
|
| 145 |
+
|
| 146 |
for entry in pareto_front:
|
| 147 |
v = entry['version']
|
| 148 |
p = entry.get('parent')
|
| 149 |
desc = entry.get('description', 'Baseline')
|
| 150 |
e = entry['evaluation']
|
| 151 |
+
|
| 152 |
print(f"SOP v{v} {f'(Child of v{p})' if p else '(Baseline)'}")
|
| 153 |
print(f" Description: {desc}")
|
| 154 |
print(f" Clinical Accuracy: {e.clinical_accuracy.score:.3f}")
|
|
|
|
| 156 |
print(f" Actionability: {e.actionability.score:.3f}")
|
| 157 |
print(f" Clarity: {e.clarity.score:.3f}")
|
| 158 |
print(f" Safety & Completeness: {e.safety_completeness.score:.3f}")
|
| 159 |
+
|
| 160 |
# Calculate average
|
| 161 |
avg_score = np.mean(e.to_vector())
|
| 162 |
print(f" Average Score: {avg_score:.3f}")
|
| 163 |
print()
|
| 164 |
+
|
| 165 |
print("=" * 80)
|
| 166 |
print("\nRECOMMENDATION:")
|
| 167 |
print("Review the visualizations and choose the SOP that best matches")
|
|
|
|
| 169 |
print("=" * 80)
|
| 170 |
|
| 171 |
|
| 172 |
+
def analyze_improvements(gene_pool_entries: list[dict[str, Any]]):
|
| 173 |
"""Analyze improvements over baseline"""
|
| 174 |
if len(gene_pool_entries) < 2:
|
| 175 |
print("\n⚠️ Not enough SOPs to analyze improvements")
|
| 176 |
return
|
| 177 |
+
|
| 178 |
baseline = gene_pool_entries[0]
|
| 179 |
baseline_scores = np.array(baseline['evaluation'].to_vector())
|
| 180 |
+
|
| 181 |
print("\n" + "=" * 80)
|
| 182 |
print("IMPROVEMENT ANALYSIS")
|
| 183 |
print("=" * 80)
|
| 184 |
+
|
| 185 |
print(f"\nBaseline (v{baseline['version']}): {baseline.get('description', 'Initial')}")
|
| 186 |
print(f" Average Score: {np.mean(baseline_scores):.3f}")
|
| 187 |
+
|
| 188 |
improvements_found = False
|
| 189 |
for entry in gene_pool_entries[1:]:
|
| 190 |
scores = np.array(entry['evaluation'].to_vector())
|
| 191 |
avg_score = np.mean(scores)
|
| 192 |
baseline_avg = np.mean(baseline_scores)
|
| 193 |
+
|
| 194 |
if avg_score > baseline_avg:
|
| 195 |
improvements_found = True
|
| 196 |
improvement_pct = ((avg_score - baseline_avg) / baseline_avg) * 100
|
| 197 |
+
|
| 198 |
+
print(f"\n✓ SOP v{entry['version']}: {entry.get('description', '')}")
|
| 199 |
print(f" Average Score: {avg_score:.3f} (+{improvement_pct:.1f}% vs baseline)")
|
| 200 |
+
|
| 201 |
# Show per-metric improvements
|
| 202 |
+
metric_names = ['Clinical Accuracy', 'Evidence Grounding', 'Actionability',
|
| 203 |
'Clarity', 'Safety & Completeness']
|
| 204 |
for i, (name, score, baseline_score) in enumerate(zip(metric_names, scores, baseline_scores)):
|
| 205 |
diff = score - baseline_score
|
| 206 |
if abs(diff) > 0.01: # Show significant changes
|
| 207 |
symbol = "↑" if diff > 0 else "↓"
|
| 208 |
print(f" {name}: {score:.3f} {symbol} ({diff:+.3f})")
|
| 209 |
+
|
| 210 |
if not improvements_found:
|
| 211 |
print("\n⚠️ No improvements found over baseline yet")
|
| 212 |
print(" Consider running more evolution cycles or adjusting mutation strategies")
|
| 213 |
+
|
| 214 |
print("\n" + "=" * 80)
|
archive/sop_evolution.py
CHANGED
|
@@ -8,9 +8,10 @@ from __future__ import annotations
|
|
| 8 |
|
| 9 |
from datetime import datetime, timedelta
|
| 10 |
|
| 11 |
-
from airflow import DAG
|
| 12 |
from airflow.operators.python import PythonOperator
|
| 13 |
|
|
|
|
|
|
|
| 14 |
default_args = {
|
| 15 |
"owner": "mediguard",
|
| 16 |
"retries": 1,
|
|
|
|
| 8 |
|
| 9 |
from datetime import datetime, timedelta
|
| 10 |
|
|
|
|
| 11 |
from airflow.operators.python import PythonOperator
|
| 12 |
|
| 13 |
+
from airflow import DAG
|
| 14 |
+
|
| 15 |
default_args = {
|
| 16 |
"owner": "mediguard",
|
| 17 |
"retries": 1,
|
{tests → archive/tests}/test_evolution_loop.py
RENAMED
|
@@ -10,20 +10,20 @@ from pathlib import Path
|
|
| 10 |
project_root = Path(__file__).parent.parent
|
| 11 |
sys.path.insert(0, str(project_root))
|
| 12 |
|
| 13 |
-
from
|
| 14 |
-
from
|
|
|
|
| 15 |
from src.config import BASELINE_SOP
|
| 16 |
-
from src.state import PatientInput, GuildState
|
| 17 |
from src.evaluation.evaluators import run_full_evaluation
|
| 18 |
from src.evolution.director import SOPGenePool, run_evolution_cycle
|
| 19 |
from src.evolution.pareto import (
|
|
|
|
| 20 |
identify_pareto_front,
|
| 21 |
-
visualize_pareto_frontier,
|
| 22 |
print_pareto_summary,
|
| 23 |
-
|
| 24 |
)
|
| 25 |
-
from
|
| 26 |
-
from
|
| 27 |
|
| 28 |
|
| 29 |
def create_test_patient() -> PatientInput:
|
|
@@ -53,8 +53,8 @@ def create_test_patient() -> PatientInput:
|
|
| 53 |
"Chloride": 102.0,
|
| 54 |
"Bicarbonate": 24.0
|
| 55 |
}
|
| 56 |
-
|
| 57 |
-
model_prediction:
|
| 58 |
'disease': 'Type 2 Diabetes',
|
| 59 |
'confidence': 0.92,
|
| 60 |
'probabilities': {
|
|
@@ -64,7 +64,7 @@ def create_test_patient() -> PatientInput:
|
|
| 64 |
},
|
| 65 |
'prediction_timestamp': '2025-01-01T10:00:00'
|
| 66 |
}
|
| 67 |
-
|
| 68 |
patient_context = {
|
| 69 |
'patient_id': 'TEST-001',
|
| 70 |
'age': 55,
|
|
@@ -74,7 +74,7 @@ def create_test_patient() -> PatientInput:
|
|
| 74 |
'current_medications': ["Metformin 500mg"],
|
| 75 |
'query': "My blood sugar has been high lately. What should I do?"
|
| 76 |
}
|
| 77 |
-
|
| 78 |
return PatientInput(
|
| 79 |
biomarkers=biomarkers,
|
| 80 |
model_prediction=model_prediction,
|
|
@@ -87,19 +87,19 @@ def main():
|
|
| 87 |
print("\n" + "=" * 80)
|
| 88 |
print("PHASE 3: SELF-IMPROVEMENT LOOP TEST")
|
| 89 |
print("=" * 80)
|
| 90 |
-
|
| 91 |
# Setup
|
| 92 |
print("\n1. Initializing system...")
|
| 93 |
guild = create_guild()
|
| 94 |
patient = create_test_patient()
|
| 95 |
-
|
| 96 |
# Initialize gene pool with baseline
|
| 97 |
print("\n2. Creating SOP Gene Pool...")
|
| 98 |
gene_pool = SOPGenePool()
|
| 99 |
-
|
| 100 |
print("\n3. Evaluating Baseline SOP...")
|
| 101 |
# Run workflow with baseline SOP
|
| 102 |
-
|
| 103 |
initial_state: GuildState = {
|
| 104 |
'patient_biomarkers': patient.biomarkers,
|
| 105 |
'model_prediction': patient.model_prediction,
|
|
@@ -113,41 +113,41 @@ def main():
|
|
| 113 |
'processing_timestamp': datetime.now().isoformat(),
|
| 114 |
'sop_version': "Baseline"
|
| 115 |
}
|
| 116 |
-
|
| 117 |
guild_state = guild.workflow.invoke(initial_state)
|
| 118 |
-
|
| 119 |
baseline_response = guild_state['final_response']
|
| 120 |
agent_outputs = guild_state['agent_outputs']
|
| 121 |
-
|
| 122 |
baseline_eval = run_full_evaluation(
|
| 123 |
final_response=baseline_response,
|
| 124 |
agent_outputs=agent_outputs,
|
| 125 |
biomarkers=patient.biomarkers
|
| 126 |
)
|
| 127 |
-
|
| 128 |
gene_pool.add(
|
| 129 |
sop=BASELINE_SOP,
|
| 130 |
evaluation=baseline_eval,
|
| 131 |
parent_version=None,
|
| 132 |
description="Baseline SOP"
|
| 133 |
)
|
| 134 |
-
|
| 135 |
print(f"\n✓ Baseline Average Score: {baseline_eval.average_score():.3f}")
|
| 136 |
print(f" Clinical Accuracy: {baseline_eval.clinical_accuracy.score:.3f}")
|
| 137 |
print(f" Evidence Grounding: {baseline_eval.evidence_grounding.score:.3f}")
|
| 138 |
print(f" Actionability: {baseline_eval.actionability.score:.3f}")
|
| 139 |
print(f" Clarity: {baseline_eval.clarity.score:.3f}")
|
| 140 |
print(f" Safety & Completeness: {baseline_eval.safety_completeness.score:.3f}")
|
| 141 |
-
|
| 142 |
# Run evolution cycles
|
| 143 |
num_cycles = 2
|
| 144 |
print(f"\n4. Running {num_cycles} Evolution Cycles...")
|
| 145 |
-
|
| 146 |
for cycle in range(1, num_cycles + 1):
|
| 147 |
print(f"\n{'─' * 80}")
|
| 148 |
print(f"EVOLUTION CYCLE {cycle}")
|
| 149 |
print(f"{'─' * 80}")
|
| 150 |
-
|
| 151 |
try:
|
| 152 |
# Create evaluation function for this cycle
|
| 153 |
def eval_func(final_response, agent_outputs, biomarkers):
|
|
@@ -156,61 +156,61 @@ def main():
|
|
| 156 |
agent_outputs=agent_outputs,
|
| 157 |
biomarkers=biomarkers
|
| 158 |
)
|
| 159 |
-
|
| 160 |
new_entries = run_evolution_cycle(
|
| 161 |
gene_pool=gene_pool,
|
| 162 |
patient_input=patient,
|
| 163 |
workflow_graph=guild.workflow,
|
| 164 |
evaluation_func=eval_func
|
| 165 |
)
|
| 166 |
-
|
| 167 |
print(f"\n✓ Cycle {cycle} complete: Added {len(new_entries)} new SOPs to gene pool")
|
| 168 |
-
|
| 169 |
for entry in new_entries:
|
| 170 |
print(f"\n SOP v{entry['version']}: {entry['description']}")
|
| 171 |
print(f" Average Score: {entry['evaluation'].average_score():.3f}")
|
| 172 |
-
|
| 173 |
except Exception as e:
|
| 174 |
print(f"\n⚠️ Cycle {cycle} encountered error: {e}")
|
| 175 |
print("Continuing to next cycle...")
|
| 176 |
-
|
| 177 |
# Show gene pool summary
|
| 178 |
print("\n5. Gene Pool Summary:")
|
| 179 |
gene_pool.summary()
|
| 180 |
-
|
| 181 |
# Pareto Analysis
|
| 182 |
print("\n6. Identifying Pareto Frontier...")
|
| 183 |
all_entries = gene_pool.gene_pool
|
| 184 |
pareto_front = identify_pareto_front(all_entries)
|
| 185 |
-
|
| 186 |
print(f"\n✓ Pareto frontier contains {len(pareto_front)} non-dominated solutions")
|
| 187 |
print_pareto_summary(pareto_front)
|
| 188 |
-
|
| 189 |
# Improvement Analysis
|
| 190 |
print("\n7. Analyzing Improvements...")
|
| 191 |
analyze_improvements(all_entries)
|
| 192 |
-
|
| 193 |
# Visualizations
|
| 194 |
print("\n8. Generating Visualizations...")
|
| 195 |
visualize_pareto_frontier(pareto_front)
|
| 196 |
-
|
| 197 |
# Final Summary
|
| 198 |
print("\n" + "=" * 80)
|
| 199 |
print("EVOLUTION TEST COMPLETE")
|
| 200 |
print("=" * 80)
|
| 201 |
-
|
| 202 |
print(f"\n✓ Total SOPs in Gene Pool: {len(all_entries)}")
|
| 203 |
print(f"✓ Pareto Optimal SOPs: {len(pareto_front)}")
|
| 204 |
-
|
| 205 |
# Find best average score
|
| 206 |
best_sop = max(all_entries, key=lambda e: e['evaluation'].average_score())
|
| 207 |
baseline_avg = baseline_eval.average_score()
|
| 208 |
best_avg = best_sop['evaluation'].average_score()
|
| 209 |
improvement = ((best_avg - baseline_avg) / baseline_avg) * 100
|
| 210 |
-
|
| 211 |
print(f"\nBest SOP: v{best_sop['version']} - {best_sop['description']}")
|
| 212 |
print(f" Average Score: {best_avg:.3f} ({improvement:+.1f}% vs baseline)")
|
| 213 |
-
|
| 214 |
print("\n✓ Visualization saved to: data/pareto_frontier_analysis.png")
|
| 215 |
print("\n" + "=" * 80)
|
| 216 |
|
|
|
|
| 10 |
project_root = Path(__file__).parent.parent
|
| 11 |
sys.path.insert(0, str(project_root))
|
| 12 |
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
from src.config import BASELINE_SOP
|
|
|
|
| 17 |
from src.evaluation.evaluators import run_full_evaluation
|
| 18 |
from src.evolution.director import SOPGenePool, run_evolution_cycle
|
| 19 |
from src.evolution.pareto import (
|
| 20 |
+
analyze_improvements,
|
| 21 |
identify_pareto_front,
|
|
|
|
| 22 |
print_pareto_summary,
|
| 23 |
+
visualize_pareto_frontier,
|
| 24 |
)
|
| 25 |
+
from src.state import GuildState, PatientInput
|
| 26 |
+
from src.workflow import create_guild
|
| 27 |
|
| 28 |
|
| 29 |
def create_test_patient() -> PatientInput:
|
|
|
|
| 53 |
"Chloride": 102.0,
|
| 54 |
"Bicarbonate": 24.0
|
| 55 |
}
|
| 56 |
+
|
| 57 |
+
model_prediction: dict[str, Any] = {
|
| 58 |
'disease': 'Type 2 Diabetes',
|
| 59 |
'confidence': 0.92,
|
| 60 |
'probabilities': {
|
|
|
|
| 64 |
},
|
| 65 |
'prediction_timestamp': '2025-01-01T10:00:00'
|
| 66 |
}
|
| 67 |
+
|
| 68 |
patient_context = {
|
| 69 |
'patient_id': 'TEST-001',
|
| 70 |
'age': 55,
|
|
|
|
| 74 |
'current_medications': ["Metformin 500mg"],
|
| 75 |
'query': "My blood sugar has been high lately. What should I do?"
|
| 76 |
}
|
| 77 |
+
|
| 78 |
return PatientInput(
|
| 79 |
biomarkers=biomarkers,
|
| 80 |
model_prediction=model_prediction,
|
|
|
|
| 87 |
print("\n" + "=" * 80)
|
| 88 |
print("PHASE 3: SELF-IMPROVEMENT LOOP TEST")
|
| 89 |
print("=" * 80)
|
| 90 |
+
|
| 91 |
# Setup
|
| 92 |
print("\n1. Initializing system...")
|
| 93 |
guild = create_guild()
|
| 94 |
patient = create_test_patient()
|
| 95 |
+
|
| 96 |
# Initialize gene pool with baseline
|
| 97 |
print("\n2. Creating SOP Gene Pool...")
|
| 98 |
gene_pool = SOPGenePool()
|
| 99 |
+
|
| 100 |
print("\n3. Evaluating Baseline SOP...")
|
| 101 |
# Run workflow with baseline SOP
|
| 102 |
+
|
| 103 |
initial_state: GuildState = {
|
| 104 |
'patient_biomarkers': patient.biomarkers,
|
| 105 |
'model_prediction': patient.model_prediction,
|
|
|
|
| 113 |
'processing_timestamp': datetime.now().isoformat(),
|
| 114 |
'sop_version': "Baseline"
|
| 115 |
}
|
| 116 |
+
|
| 117 |
guild_state = guild.workflow.invoke(initial_state)
|
| 118 |
+
|
| 119 |
baseline_response = guild_state['final_response']
|
| 120 |
agent_outputs = guild_state['agent_outputs']
|
| 121 |
+
|
| 122 |
baseline_eval = run_full_evaluation(
|
| 123 |
final_response=baseline_response,
|
| 124 |
agent_outputs=agent_outputs,
|
| 125 |
biomarkers=patient.biomarkers
|
| 126 |
)
|
| 127 |
+
|
| 128 |
gene_pool.add(
|
| 129 |
sop=BASELINE_SOP,
|
| 130 |
evaluation=baseline_eval,
|
| 131 |
parent_version=None,
|
| 132 |
description="Baseline SOP"
|
| 133 |
)
|
| 134 |
+
|
| 135 |
print(f"\n✓ Baseline Average Score: {baseline_eval.average_score():.3f}")
|
| 136 |
print(f" Clinical Accuracy: {baseline_eval.clinical_accuracy.score:.3f}")
|
| 137 |
print(f" Evidence Grounding: {baseline_eval.evidence_grounding.score:.3f}")
|
| 138 |
print(f" Actionability: {baseline_eval.actionability.score:.3f}")
|
| 139 |
print(f" Clarity: {baseline_eval.clarity.score:.3f}")
|
| 140 |
print(f" Safety & Completeness: {baseline_eval.safety_completeness.score:.3f}")
|
| 141 |
+
|
| 142 |
# Run evolution cycles
|
| 143 |
num_cycles = 2
|
| 144 |
print(f"\n4. Running {num_cycles} Evolution Cycles...")
|
| 145 |
+
|
| 146 |
for cycle in range(1, num_cycles + 1):
|
| 147 |
print(f"\n{'─' * 80}")
|
| 148 |
print(f"EVOLUTION CYCLE {cycle}")
|
| 149 |
print(f"{'─' * 80}")
|
| 150 |
+
|
| 151 |
try:
|
| 152 |
# Create evaluation function for this cycle
|
| 153 |
def eval_func(final_response, agent_outputs, biomarkers):
|
|
|
|
| 156 |
agent_outputs=agent_outputs,
|
| 157 |
biomarkers=biomarkers
|
| 158 |
)
|
| 159 |
+
|
| 160 |
new_entries = run_evolution_cycle(
|
| 161 |
gene_pool=gene_pool,
|
| 162 |
patient_input=patient,
|
| 163 |
workflow_graph=guild.workflow,
|
| 164 |
evaluation_func=eval_func
|
| 165 |
)
|
| 166 |
+
|
| 167 |
print(f"\n✓ Cycle {cycle} complete: Added {len(new_entries)} new SOPs to gene pool")
|
| 168 |
+
|
| 169 |
for entry in new_entries:
|
| 170 |
print(f"\n SOP v{entry['version']}: {entry['description']}")
|
| 171 |
print(f" Average Score: {entry['evaluation'].average_score():.3f}")
|
| 172 |
+
|
| 173 |
except Exception as e:
|
| 174 |
print(f"\n⚠️ Cycle {cycle} encountered error: {e}")
|
| 175 |
print("Continuing to next cycle...")
|
| 176 |
+
|
| 177 |
# Show gene pool summary
|
| 178 |
print("\n5. Gene Pool Summary:")
|
| 179 |
gene_pool.summary()
|
| 180 |
+
|
| 181 |
# Pareto Analysis
|
| 182 |
print("\n6. Identifying Pareto Frontier...")
|
| 183 |
all_entries = gene_pool.gene_pool
|
| 184 |
pareto_front = identify_pareto_front(all_entries)
|
| 185 |
+
|
| 186 |
print(f"\n✓ Pareto frontier contains {len(pareto_front)} non-dominated solutions")
|
| 187 |
print_pareto_summary(pareto_front)
|
| 188 |
+
|
| 189 |
# Improvement Analysis
|
| 190 |
print("\n7. Analyzing Improvements...")
|
| 191 |
analyze_improvements(all_entries)
|
| 192 |
+
|
| 193 |
# Visualizations
|
| 194 |
print("\n8. Generating Visualizations...")
|
| 195 |
visualize_pareto_frontier(pareto_front)
|
| 196 |
+
|
| 197 |
# Final Summary
|
| 198 |
print("\n" + "=" * 80)
|
| 199 |
print("EVOLUTION TEST COMPLETE")
|
| 200 |
print("=" * 80)
|
| 201 |
+
|
| 202 |
print(f"\n✓ Total SOPs in Gene Pool: {len(all_entries)}")
|
| 203 |
print(f"✓ Pareto Optimal SOPs: {len(pareto_front)}")
|
| 204 |
+
|
| 205 |
# Find best average score
|
| 206 |
best_sop = max(all_entries, key=lambda e: e['evaluation'].average_score())
|
| 207 |
baseline_avg = baseline_eval.average_score()
|
| 208 |
best_avg = best_sop['evaluation'].average_score()
|
| 209 |
improvement = ((best_avg - baseline_avg) / baseline_avg) * 100
|
| 210 |
+
|
| 211 |
print(f"\nBest SOP: v{best_sop['version']} - {best_sop['description']}")
|
| 212 |
print(f" Average Score: {best_avg:.3f} ({improvement:+.1f}% vs baseline)")
|
| 213 |
+
|
| 214 |
print("\n✓ Visualization saved to: data/pareto_frontier_analysis.png")
|
| 215 |
print("\n" + "=" * 80)
|
| 216 |
|
{tests → archive/tests}/test_evolution_quick.py
RENAMED
|
@@ -5,6 +5,7 @@ Tests gene pool, diagnostician, and architect without full workflow
|
|
| 5 |
|
| 6 |
import sys
|
| 7 |
from pathlib import Path
|
|
|
|
| 8 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 9 |
|
| 10 |
from src.config import BASELINE_SOP
|
|
@@ -17,11 +18,11 @@ def main():
|
|
| 17 |
print("\n" + "=" * 80)
|
| 18 |
print("QUICK PHASE 3 TEST")
|
| 19 |
print("=" * 80)
|
| 20 |
-
|
| 21 |
# Test 1: Gene Pool
|
| 22 |
print("\n1. Testing Gene Pool...")
|
| 23 |
gene_pool = SOPGenePool()
|
| 24 |
-
|
| 25 |
# Create mock evaluation (baseline with low clarity)
|
| 26 |
baseline_eval = EvaluationResult(
|
| 27 |
clinical_accuracy=GradedScore(score=0.95, reasoning="Accurate"),
|
|
@@ -30,48 +31,48 @@ def main():
|
|
| 30 |
clarity=GradedScore(score=0.75, reasoning="Could be clearer"),
|
| 31 |
safety_completeness=GradedScore(score=1.0, reasoning="Complete")
|
| 32 |
)
|
| 33 |
-
|
| 34 |
gene_pool.add(
|
| 35 |
sop=BASELINE_SOP,
|
| 36 |
evaluation=baseline_eval,
|
| 37 |
parent_version=None,
|
| 38 |
description="Baseline SOP"
|
| 39 |
)
|
| 40 |
-
|
| 41 |
-
print(
|
| 42 |
print(f" Average score: {baseline_eval.average_score():.3f}")
|
| 43 |
-
|
| 44 |
# Test 2: Performance Diagnostician
|
| 45 |
print("\n2. Testing Performance Diagnostician...")
|
| 46 |
diagnosis = performance_diagnostician(baseline_eval)
|
| 47 |
-
|
| 48 |
-
print(
|
| 49 |
print(f" Primary weakness: {diagnosis.primary_weakness}")
|
| 50 |
print(f" Root cause: {diagnosis.root_cause_analysis[:100]}...")
|
| 51 |
print(f" Recommendation: {diagnosis.recommendation[:100]}...")
|
| 52 |
-
|
| 53 |
# Test 3: SOP Architect
|
| 54 |
print("\n3. Testing SOP Architect...")
|
| 55 |
evolved_sops = sop_architect(diagnosis, BASELINE_SOP)
|
| 56 |
-
|
| 57 |
print(f"\n✓ Generated {len(evolved_sops.mutations)} mutations")
|
| 58 |
for i, mutation in enumerate(evolved_sops.mutations, 1):
|
| 59 |
print(f"\n Mutation {i}: {mutation.description}")
|
| 60 |
print(f" Disease explainer K: {mutation.disease_explainer_k}")
|
| 61 |
print(f" Detail level: {mutation.explainer_detail_level}")
|
| 62 |
print(f" Citations required: {mutation.require_pdf_citations}")
|
| 63 |
-
|
| 64 |
# Test 4: Gene Pool Summary
|
| 65 |
print("\n4. Gene Pool Summary:")
|
| 66 |
gene_pool.summary()
|
| 67 |
-
|
| 68 |
# Test 5: Average score method
|
| 69 |
print("\n5. Testing average_score method...")
|
| 70 |
avg = baseline_eval.average_score()
|
| 71 |
print(f"✓ Average score calculation: {avg:.3f}")
|
| 72 |
vector = baseline_eval.to_vector()
|
| 73 |
print(f"✓ Score vector: {[f'{s:.2f}' for s in vector]}")
|
| 74 |
-
|
| 75 |
print("\n" + "=" * 80)
|
| 76 |
print("QUICK TEST COMPLETE")
|
| 77 |
print("=" * 80)
|
|
|
|
| 5 |
|
| 6 |
import sys
|
| 7 |
from pathlib import Path
|
| 8 |
+
|
| 9 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 10 |
|
| 11 |
from src.config import BASELINE_SOP
|
|
|
|
| 18 |
print("\n" + "=" * 80)
|
| 19 |
print("QUICK PHASE 3 TEST")
|
| 20 |
print("=" * 80)
|
| 21 |
+
|
| 22 |
# Test 1: Gene Pool
|
| 23 |
print("\n1. Testing Gene Pool...")
|
| 24 |
gene_pool = SOPGenePool()
|
| 25 |
+
|
| 26 |
# Create mock evaluation (baseline with low clarity)
|
| 27 |
baseline_eval = EvaluationResult(
|
| 28 |
clinical_accuracy=GradedScore(score=0.95, reasoning="Accurate"),
|
|
|
|
| 31 |
clarity=GradedScore(score=0.75, reasoning="Could be clearer"),
|
| 32 |
safety_completeness=GradedScore(score=1.0, reasoning="Complete")
|
| 33 |
)
|
| 34 |
+
|
| 35 |
gene_pool.add(
|
| 36 |
sop=BASELINE_SOP,
|
| 37 |
evaluation=baseline_eval,
|
| 38 |
parent_version=None,
|
| 39 |
description="Baseline SOP"
|
| 40 |
)
|
| 41 |
+
|
| 42 |
+
print("✓ Gene pool initialized with 1 SOP")
|
| 43 |
print(f" Average score: {baseline_eval.average_score():.3f}")
|
| 44 |
+
|
| 45 |
# Test 2: Performance Diagnostician
|
| 46 |
print("\n2. Testing Performance Diagnostician...")
|
| 47 |
diagnosis = performance_diagnostician(baseline_eval)
|
| 48 |
+
|
| 49 |
+
print("✓ Diagnosis complete")
|
| 50 |
print(f" Primary weakness: {diagnosis.primary_weakness}")
|
| 51 |
print(f" Root cause: {diagnosis.root_cause_analysis[:100]}...")
|
| 52 |
print(f" Recommendation: {diagnosis.recommendation[:100]}...")
|
| 53 |
+
|
| 54 |
# Test 3: SOP Architect
|
| 55 |
print("\n3. Testing SOP Architect...")
|
| 56 |
evolved_sops = sop_architect(diagnosis, BASELINE_SOP)
|
| 57 |
+
|
| 58 |
print(f"\n✓ Generated {len(evolved_sops.mutations)} mutations")
|
| 59 |
for i, mutation in enumerate(evolved_sops.mutations, 1):
|
| 60 |
print(f"\n Mutation {i}: {mutation.description}")
|
| 61 |
print(f" Disease explainer K: {mutation.disease_explainer_k}")
|
| 62 |
print(f" Detail level: {mutation.explainer_detail_level}")
|
| 63 |
print(f" Citations required: {mutation.require_pdf_citations}")
|
| 64 |
+
|
| 65 |
# Test 4: Gene Pool Summary
|
| 66 |
print("\n4. Gene Pool Summary:")
|
| 67 |
gene_pool.summary()
|
| 68 |
+
|
| 69 |
# Test 5: Average score method
|
| 70 |
print("\n5. Testing average_score method...")
|
| 71 |
avg = baseline_eval.average_score()
|
| 72 |
print(f"✓ Average score calculation: {avg:.3f}")
|
| 73 |
vector = baseline_eval.to_vector()
|
| 74 |
print(f"✓ Score vector: {[f'{s:.2f}' for s in vector]}")
|
| 75 |
+
|
| 76 |
print("\n" + "=" * 80)
|
| 77 |
print("QUICK TEST COMPLETE")
|
| 78 |
print("=" * 80)
|
docker-compose.yml
CHANGED
|
@@ -143,6 +143,26 @@ services:
|
|
| 143 |
# count: 1
|
| 144 |
# capabilities: [gpu]
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
# -----------------------------------------------------------------------
|
| 147 |
# Observability
|
| 148 |
# -----------------------------------------------------------------------
|
|
|
|
| 143 |
# count: 1
|
| 144 |
# capabilities: [gpu]
|
| 145 |
|
| 146 |
+
airflow:
|
| 147 |
+
image: apache/airflow:2.8.2
|
| 148 |
+
container_name: mediguard-airflow
|
| 149 |
+
environment:
|
| 150 |
+
- AIRFLOW__CORE__LOAD_EXAMPLES=false
|
| 151 |
+
- AIRFLOW__CORE__EXECUTOR=LocalExecutor
|
| 152 |
+
- AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://${POSTGRES__USER:-mediguard}:${POSTGRES__PASSWORD:-mediguard_secret}@postgres:5432/${POSTGRES__DATABASE:-mediguard}
|
| 153 |
+
command: standalone
|
| 154 |
+
ports:
|
| 155 |
+
- "${AIRFLOW_PORT:-8080}:8080"
|
| 156 |
+
volumes:
|
| 157 |
+
- ./airflow/dags:/opt/airflow/dags:ro
|
| 158 |
+
- ./data/medical_pdfs:/app/data/medical_pdfs:ro
|
| 159 |
+
- .:/app:ro
|
| 160 |
+
working_dir: /app
|
| 161 |
+
depends_on:
|
| 162 |
+
postgres:
|
| 163 |
+
condition: service_healthy
|
| 164 |
+
restart: unless-stopped
|
| 165 |
+
|
| 166 |
# -----------------------------------------------------------------------
|
| 167 |
# Observability
|
| 168 |
# -----------------------------------------------------------------------
|
gradio_launcher.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MediGuard AI — Gradio Launcher wrapper.
|
| 3 |
+
|
| 4 |
+
Spawns the Gradio frontend UI on the correct designated port (7861), separating
|
| 5 |
+
the frontend runner from the production API layer entirely.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
# Ensure project root is in path
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 16 |
+
|
| 17 |
+
from src.gradio_app import launch_gradio
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(level=logging.INFO)
|
| 20 |
+
|
| 21 |
+
if __name__ == "__main__":
|
| 22 |
+
port = int(os.environ.get("GRADIO_PORT", 7861))
|
| 23 |
+
logging.info("Starting Gradio Web UI Launcher on port %d...", port)
|
| 24 |
+
launch_gradio(share=False, server_port=port)
|
huggingface/app.py
CHANGED
|
@@ -37,7 +37,7 @@ import sys
|
|
| 37 |
import time
|
| 38 |
import traceback
|
| 39 |
from pathlib import Path
|
| 40 |
-
from typing import Any
|
| 41 |
|
| 42 |
# Ensure project root is in path
|
| 43 |
_project_root = str(Path(__file__).parent.parent)
|
|
@@ -114,7 +114,7 @@ def setup_llm_provider():
|
|
| 114 |
"""
|
| 115 |
groq_key, google_key = get_api_keys()
|
| 116 |
provider = None
|
| 117 |
-
|
| 118 |
if groq_key:
|
| 119 |
os.environ["LLM_PROVIDER"] = "groq"
|
| 120 |
os.environ["GROQ_API_KEY"] = groq_key
|
|
@@ -127,18 +127,18 @@ def setup_llm_provider():
|
|
| 127 |
os.environ["GEMINI_MODEL"] = get_gemini_model()
|
| 128 |
provider = "gemini"
|
| 129 |
logger.info(f"Configured Gemini provider with model: {get_gemini_model()}")
|
| 130 |
-
|
| 131 |
# Set up embedding provider
|
| 132 |
embedding_provider = get_embedding_provider()
|
| 133 |
os.environ["EMBEDDING_PROVIDER"] = embedding_provider
|
| 134 |
-
|
| 135 |
# If Jina is configured, set the API key
|
| 136 |
jina_key = get_jina_api_key()
|
| 137 |
if jina_key:
|
| 138 |
os.environ["JINA_API_KEY"] = jina_key
|
| 139 |
os.environ["EMBEDDING__JINA_API_KEY"] = jina_key
|
| 140 |
logger.info("Jina embeddings configured")
|
| 141 |
-
|
| 142 |
# Set up Langfuse if enabled
|
| 143 |
if is_langfuse_enabled():
|
| 144 |
os.environ["LANGFUSE__ENABLED"] = "true"
|
|
@@ -147,7 +147,7 @@ def setup_llm_provider():
|
|
| 147 |
if val:
|
| 148 |
os.environ[var] = val
|
| 149 |
logger.info("Langfuse observability enabled")
|
| 150 |
-
|
| 151 |
return provider
|
| 152 |
|
| 153 |
|
|
@@ -192,21 +192,21 @@ def reset_guild():
|
|
| 192 |
def get_guild():
|
| 193 |
"""Lazy initialization of the Clinical Insight Guild."""
|
| 194 |
global _guild, _guild_error, _guild_provider
|
| 195 |
-
|
| 196 |
# Check if we need to reinitialize (provider changed)
|
| 197 |
current_provider = os.getenv("LLM_PROVIDER")
|
| 198 |
if _guild_provider and _guild_provider != current_provider:
|
| 199 |
logger.info(f"Provider changed from {_guild_provider} to {current_provider}, reinitializing...")
|
| 200 |
reset_guild()
|
| 201 |
-
|
| 202 |
if _guild is not None:
|
| 203 |
return _guild
|
| 204 |
-
|
| 205 |
if _guild_error is not None:
|
| 206 |
# Don't cache errors forever - allow retry
|
| 207 |
logger.warning("Previous initialization failed, retrying...")
|
| 208 |
_guild_error = None
|
| 209 |
-
|
| 210 |
try:
|
| 211 |
logger.info("Initializing Clinical Insight Guild...")
|
| 212 |
logger.info(f" LLM_PROVIDER: {os.getenv('LLM_PROVIDER', 'not set')}")
|
|
@@ -214,17 +214,17 @@ def get_guild():
|
|
| 214 |
logger.info(f" GOOGLE_API_KEY: {'✓ set' if os.getenv('GOOGLE_API_KEY') else '✗ not set'}")
|
| 215 |
logger.info(f" EMBEDDING_PROVIDER: {os.getenv('EMBEDDING_PROVIDER', 'huggingface')}")
|
| 216 |
logger.info(f" JINA_API_KEY: {'✓ set' if os.getenv('JINA_API_KEY') else '✗ not set'}")
|
| 217 |
-
|
| 218 |
start = time.time()
|
| 219 |
-
|
| 220 |
from src.workflow import create_guild
|
| 221 |
_guild = create_guild()
|
| 222 |
_guild_provider = current_provider
|
| 223 |
-
|
| 224 |
elapsed = time.time() - start
|
| 225 |
logger.info(f"Guild initialized in {elapsed:.1f}s")
|
| 226 |
return _guild
|
| 227 |
-
|
| 228 |
except Exception as exc:
|
| 229 |
logger.error(f"Failed to initialize guild: {exc}")
|
| 230 |
_guild_error = exc
|
|
@@ -237,11 +237,8 @@ def get_guild():
|
|
| 237 |
|
| 238 |
# Import shared parsing and prediction logic
|
| 239 |
from src.shared_utils import (
|
| 240 |
-
parse_biomarkers,
|
| 241 |
get_primary_prediction,
|
| 242 |
-
|
| 243 |
-
severity_to_emoji,
|
| 244 |
-
format_confidence_percent,
|
| 245 |
)
|
| 246 |
|
| 247 |
|
|
@@ -267,10 +264,10 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
|
|
| 267 |
<p style="margin: 8px 0 0 0; color: #64748b;">Please enter biomarkers to analyze.</p>
|
| 268 |
</div>
|
| 269 |
"""
|
| 270 |
-
|
| 271 |
# Check API key dynamically (HF injects secrets after startup)
|
| 272 |
groq_key, google_key = get_api_keys()
|
| 273 |
-
|
| 274 |
if not groq_key and not google_key:
|
| 275 |
return "", "", """
|
| 276 |
<div style="background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%); border: 1px solid #ef4444; border-radius: 10px; padding: 16px;">
|
|
@@ -297,15 +294,15 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
|
|
| 297 |
</details>
|
| 298 |
</div>
|
| 299 |
"""
|
| 300 |
-
|
| 301 |
# Setup provider based on available key
|
| 302 |
provider = setup_llm_provider()
|
| 303 |
logger.info(f"Using LLM provider: {provider}")
|
| 304 |
-
|
| 305 |
try:
|
| 306 |
progress(0.1, desc="📝 Parsing biomarkers...")
|
| 307 |
biomarkers = parse_biomarkers(input_text)
|
| 308 |
-
|
| 309 |
if not biomarkers:
|
| 310 |
return "", "", """
|
| 311 |
<div style="background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); border: 1px solid #fbbf24; border-radius: 10px; padding: 16px;">
|
|
@@ -317,42 +314,42 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
|
|
| 317 |
</ul>
|
| 318 |
</div>
|
| 319 |
"""
|
| 320 |
-
|
| 321 |
progress(0.2, desc="🔧 Initializing AI agents...")
|
| 322 |
-
|
| 323 |
# Initialize guild
|
| 324 |
guild = get_guild()
|
| 325 |
-
|
| 326 |
# Prepare input
|
| 327 |
from src.state import PatientInput
|
| 328 |
-
|
| 329 |
# Auto-generate prediction based on common patterns
|
| 330 |
prediction = auto_predict(biomarkers)
|
| 331 |
-
|
| 332 |
patient_input = PatientInput(
|
| 333 |
biomarkers=biomarkers,
|
| 334 |
model_prediction=prediction,
|
| 335 |
patient_context={"patient_id": "HF_User", "source": "huggingface_spaces"}
|
| 336 |
)
|
| 337 |
-
|
| 338 |
progress(0.4, desc="🤖 Running Clinical Insight Guild...")
|
| 339 |
-
|
| 340 |
# Run analysis
|
| 341 |
start = time.time()
|
| 342 |
result = guild.run(patient_input)
|
| 343 |
elapsed = time.time() - start
|
| 344 |
-
|
| 345 |
progress(0.9, desc="✨ Formatting results...")
|
| 346 |
-
|
| 347 |
# Extract response
|
| 348 |
final_response = result.get("final_response", {})
|
| 349 |
-
|
| 350 |
# Format summary
|
| 351 |
summary = format_summary(final_response, elapsed)
|
| 352 |
-
|
| 353 |
# Format details
|
| 354 |
details = json.dumps(final_response, indent=2, default=str)
|
| 355 |
-
|
| 356 |
status = f"""
|
| 357 |
<div style="background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%); border: 1px solid #10b981; border-radius: 10px; padding: 12px; display: flex; align-items: center; gap: 10px;">
|
| 358 |
<span style="font-size: 1.5em;">✅</span>
|
|
@@ -362,9 +359,9 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
|
|
| 362 |
</div>
|
| 363 |
</div>
|
| 364 |
"""
|
| 365 |
-
|
| 366 |
return summary, details, status
|
| 367 |
-
|
| 368 |
except Exception as exc:
|
| 369 |
logger.error(f"Analysis error: {exc}", exc_info=True)
|
| 370 |
error_msg = f"""
|
|
@@ -384,14 +381,14 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 384 |
"""Format the analysis response as clean markdown with black text."""
|
| 385 |
if not response:
|
| 386 |
return "❌ **No analysis results available.**"
|
| 387 |
-
|
| 388 |
parts = []
|
| 389 |
-
|
| 390 |
# Header with primary finding and confidence
|
| 391 |
primary = response.get("primary_finding", "Analysis Complete")
|
| 392 |
confidence = response.get("confidence", {})
|
| 393 |
conf_score = confidence.get("overall_score", 0) if isinstance(confidence, dict) else 0
|
| 394 |
-
|
| 395 |
# Determine severity
|
| 396 |
severity = response.get("severity", "low")
|
| 397 |
severity_config = {
|
|
@@ -401,14 +398,14 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 401 |
"low": ("🟢", "#16a34a", "#f0fdf4")
|
| 402 |
}
|
| 403 |
emoji, color, bg_color = severity_config.get(severity, severity_config["low"])
|
| 404 |
-
|
| 405 |
# Build confidence display
|
| 406 |
conf_badge = ""
|
| 407 |
if conf_score:
|
| 408 |
conf_pct = int(conf_score * 100)
|
| 409 |
conf_color = "#16a34a" if conf_pct >= 80 else "#ca8a04" if conf_pct >= 60 else "#dc2626"
|
| 410 |
conf_badge = f'<span style="background: {conf_color}; color: white; padding: 4px 12px; border-radius: 20px; font-size: 0.85em; margin-left: 12px;">{conf_pct}% confidence</span>'
|
| 411 |
-
|
| 412 |
parts.append(f"""
|
| 413 |
<div style="background: linear-gradient(135deg, {bg_color} 0%, white 100%); border-left: 4px solid {color}; border-radius: 12px; padding: 20px; margin-bottom: 20px;">
|
| 414 |
<div style="display: flex; align-items: center; flex-wrap: wrap;">
|
|
@@ -417,7 +414,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 417 |
{conf_badge}
|
| 418 |
</div>
|
| 419 |
</div>""")
|
| 420 |
-
|
| 421 |
# Critical Alerts
|
| 422 |
alerts = response.get("safety_alerts", [])
|
| 423 |
if alerts:
|
|
@@ -427,7 +424,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 427 |
alert_items += f'<li><strong>{alert.get("alert_type", "Alert")}:</strong> {alert.get("message", "")}</li>'
|
| 428 |
else:
|
| 429 |
alert_items += f'<li>{alert}</li>'
|
| 430 |
-
|
| 431 |
parts.append(f"""
|
| 432 |
<div style="background: linear-gradient(135deg, #fef2f2 0%, #fee2e2 100%); border: 1px solid #fecaca; border-radius: 12px; padding: 16px; margin-bottom: 16px;">
|
| 433 |
<h4 style="margin: 0 0 12px 0; color: #dc2626; display: flex; align-items: center; gap: 8px;">
|
|
@@ -436,7 +433,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 436 |
<ul style="margin: 0; padding-left: 20px; color: #991b1b;">{alert_items}</ul>
|
| 437 |
</div>
|
| 438 |
""")
|
| 439 |
-
|
| 440 |
# Key Findings
|
| 441 |
findings = response.get("key_findings", [])
|
| 442 |
if findings:
|
|
@@ -447,7 +444,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 447 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">{finding_items}</ul>
|
| 448 |
</div>
|
| 449 |
""")
|
| 450 |
-
|
| 451 |
# Biomarker Flags - as a visual grid
|
| 452 |
flags = response.get("biomarker_flags", [])
|
| 453 |
if flags and len(flags) > 0:
|
|
@@ -460,7 +457,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 460 |
continue
|
| 461 |
status = flag.get("status", "normal").lower()
|
| 462 |
value = flag.get("value", flag.get("result", "N/A"))
|
| 463 |
-
|
| 464 |
status_styles = {
|
| 465 |
"critical": ("🔴", "#dc2626", "#fef2f2"),
|
| 466 |
"high": ("🔴", "#dc2626", "#fef2f2"),
|
|
@@ -469,7 +466,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 469 |
"normal": ("🟢", "#16a34a", "#f0fdf4")
|
| 470 |
}
|
| 471 |
s_emoji, s_color, s_bg = status_styles.get(status, status_styles["normal"])
|
| 472 |
-
|
| 473 |
flag_cards += f"""
|
| 474 |
<div style="background: {s_bg}; border: 1px solid {s_color}33; border-radius: 8px; padding: 12px; text-align: center;">
|
| 475 |
<div style="font-size: 1.2em;">{s_emoji}</div>
|
|
@@ -478,7 +475,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 478 |
<div style="font-size: 0.75em; color: #64748b; text-transform: capitalize;">{status}</div>
|
| 479 |
</div>
|
| 480 |
"""
|
| 481 |
-
|
| 482 |
if flag_cards: # Only show section if we have cards
|
| 483 |
parts.append(f"""
|
| 484 |
<div style="margin-bottom: 16px;">
|
|
@@ -488,11 +485,11 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 488 |
</div>
|
| 489 |
</div>
|
| 490 |
""")
|
| 491 |
-
|
| 492 |
# Recommendations - organized sections
|
| 493 |
recs = response.get("recommendations", {})
|
| 494 |
rec_sections = ""
|
| 495 |
-
|
| 496 |
immediate = recs.get("immediate_actions", []) if isinstance(recs, dict) else []
|
| 497 |
if immediate and len(immediate) > 0:
|
| 498 |
items = "".join([f'<li style="margin-bottom: 6px;">{str(a).strip()}</li>' for a in immediate[:3]])
|
|
@@ -502,7 +499,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 502 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
|
| 503 |
</div>
|
| 504 |
"""
|
| 505 |
-
|
| 506 |
lifestyle = recs.get("lifestyle_modifications", []) if isinstance(recs, dict) else []
|
| 507 |
if lifestyle and len(lifestyle) > 0:
|
| 508 |
items = "".join([f'<li style="margin-bottom: 6px;">{str(m).strip()}</li>' for m in lifestyle[:3]])
|
|
@@ -512,7 +509,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 512 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
|
| 513 |
</div>
|
| 514 |
"""
|
| 515 |
-
|
| 516 |
followup = recs.get("follow_up", []) if isinstance(recs, dict) else []
|
| 517 |
if followup and len(followup) > 0:
|
| 518 |
items = "".join([f'<li style="margin-bottom: 6px;">{str(f).strip()}</li>' for f in followup[:3]])
|
|
@@ -522,10 +519,10 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 522 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
|
| 523 |
</div>
|
| 524 |
"""
|
| 525 |
-
|
| 526 |
# Add default recommendations if none provided
|
| 527 |
if not rec_sections:
|
| 528 |
-
rec_sections =
|
| 529 |
<div style="margin-bottom: 12px;">
|
| 530 |
<h5 style="margin: 0 0 8px 0; color: #2563eb;">📋 General Recommendations</h5>
|
| 531 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">
|
|
@@ -535,7 +532,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 535 |
</ul>
|
| 536 |
</div>
|
| 537 |
"""
|
| 538 |
-
|
| 539 |
if rec_sections:
|
| 540 |
parts.append(f"""
|
| 541 |
<div style="background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%); border-radius: 12px; padding: 16px; margin-bottom: 16px;">
|
|
@@ -543,7 +540,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 543 |
{rec_sections}
|
| 544 |
</div>
|
| 545 |
""")
|
| 546 |
-
|
| 547 |
# Disease Explanation
|
| 548 |
explanation = response.get("disease_explanation", {})
|
| 549 |
if explanation and isinstance(explanation, dict):
|
|
@@ -555,7 +552,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 555 |
<p style="margin: 0; color: #475569; line-height: 1.6;">{pathophys[:600]}{'...' if len(pathophys) > 600 else ''}</p>
|
| 556 |
</div>
|
| 557 |
""")
|
| 558 |
-
|
| 559 |
# Conversational Summary
|
| 560 |
conv_summary = response.get("conversational_summary", "")
|
| 561 |
if conv_summary:
|
|
@@ -565,7 +562,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 565 |
<p style="margin: 0; color: #475569; line-height: 1.6;">{conv_summary[:1000]}</p>
|
| 566 |
</div>
|
| 567 |
""")
|
| 568 |
-
|
| 569 |
# Footer
|
| 570 |
parts.append(f"""
|
| 571 |
<div style="border-top: 1px solid #e2e8f0; padding-top: 16px; margin-top: 8px; text-align: center;">
|
|
@@ -577,7 +574,7 @@ def format_summary(response: dict, elapsed: float) -> str:
|
|
| 577 |
</p>
|
| 578 |
</div>
|
| 579 |
""")
|
| 580 |
-
|
| 581 |
return "\n".join(parts)
|
| 582 |
|
| 583 |
|
|
@@ -606,10 +603,10 @@ def _get_rag_service():
|
|
| 606 |
_rag_service_error = None
|
| 607 |
|
| 608 |
try:
|
|
|
|
| 609 |
from src.services.agents.agentic_rag import AgenticRAGService
|
| 610 |
from src.services.agents.context import AgenticContext
|
| 611 |
from src.services.retrieval.factory import make_retriever
|
| 612 |
-
from src.llm_config import get_synthesizer
|
| 613 |
|
| 614 |
llm = get_synthesizer()
|
| 615 |
retriever = make_retriever() # auto-detects FAISS
|
|
@@ -637,8 +634,8 @@ def _get_rag_service():
|
|
| 637 |
|
| 638 |
def _fallback_qa(question: str, context_text: str = "") -> str:
|
| 639 |
"""Direct retriever+LLM fallback when agentic pipeline is unavailable."""
|
| 640 |
-
from src.services.retrieval.factory import make_retriever
|
| 641 |
from src.llm_config import get_synthesizer
|
|
|
|
| 642 |
|
| 643 |
retriever = make_retriever()
|
| 644 |
search_query = f"{context_text} {question}" if context_text.strip() else question
|
|
@@ -727,41 +724,53 @@ def answer_medical_question(
|
|
| 727 |
|
| 728 |
except Exception as exc:
|
| 729 |
logger.exception(f"Q&A error: {exc}")
|
| 730 |
-
error_msg = f"❌ Error: {
|
| 731 |
history = (chat_history or []) + [(question, error_msg)]
|
| 732 |
return error_msg, history
|
| 733 |
|
| 734 |
|
| 735 |
-
def streaming_answer(question: str, context: str
|
| 736 |
"""Stream answer using the full agentic RAG pipeline.
|
| 737 |
Falls back to direct retriever+LLM if the pipeline is unavailable.
|
| 738 |
"""
|
|
|
|
| 739 |
if not question.strip():
|
| 740 |
-
yield
|
| 741 |
return
|
| 742 |
|
| 743 |
-
|
|
|
|
| 744 |
if not groq_key and not google_key:
|
| 745 |
-
|
|
|
|
| 746 |
return
|
| 747 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 748 |
setup_llm_provider()
|
| 749 |
|
| 750 |
try:
|
| 751 |
-
|
|
|
|
| 752 |
|
| 753 |
start_time = time.time()
|
| 754 |
|
| 755 |
rag_service = _get_rag_service()
|
| 756 |
if rag_service is not None:
|
| 757 |
-
|
|
|
|
| 758 |
result = rag_service.ask(query=question, patient_context=context)
|
| 759 |
answer = result.get("final_answer", "")
|
| 760 |
guardrail = result.get("guardrail_score")
|
| 761 |
docs_relevant = len(result.get("relevant_documents", []))
|
| 762 |
docs_retrieved = len(result.get("retrieved_documents", []))
|
| 763 |
else:
|
| 764 |
-
|
|
|
|
| 765 |
answer = _fallback_qa(question, context)
|
| 766 |
guardrail = None
|
| 767 |
docs_relevant = 0
|
|
@@ -770,7 +779,8 @@ def streaming_answer(question: str, context: str = ""):
|
|
| 770 |
if not answer:
|
| 771 |
answer = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
|
| 772 |
|
| 773 |
-
|
|
|
|
| 774 |
|
| 775 |
elapsed = time.time() - start_time
|
| 776 |
|
|
@@ -779,9 +789,10 @@ def streaming_answer(question: str, context: str = ""):
|
|
| 779 |
accumulated = ""
|
| 780 |
for i, word in enumerate(words):
|
| 781 |
accumulated += word + " "
|
| 782 |
-
if i %
|
| 783 |
-
|
| 784 |
-
|
|
|
|
| 785 |
|
| 786 |
# Final response with metadata
|
| 787 |
meta_parts = [f"⏱️ {elapsed:.1f}s"]
|
|
@@ -792,15 +803,34 @@ def streaming_answer(question: str, context: str = ""):
|
|
| 792 |
meta_parts.append("🤖 Agentic RAG" if rag_service else "🤖 RAG")
|
| 793 |
meta_line = " | ".join(meta_parts)
|
| 794 |
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
*{meta_line}*
|
| 799 |
-
"""
|
| 800 |
|
| 801 |
except Exception as exc:
|
| 802 |
logger.exception(f"Streaming Q&A error: {exc}")
|
| 803 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
|
| 805 |
|
| 806 |
# ---------------------------------------------------------------------------
|
|
@@ -1039,7 +1069,7 @@ footer { display: none !important; }
|
|
| 1039 |
|
| 1040 |
def create_demo() -> gr.Blocks:
|
| 1041 |
"""Create the Gradio Blocks interface with modern medical UI."""
|
| 1042 |
-
|
| 1043 |
with gr.Blocks(
|
| 1044 |
title="Agentic RagBot - Medical Biomarker Analysis",
|
| 1045 |
theme=gr.themes.Soft(
|
|
@@ -1065,7 +1095,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1065 |
),
|
| 1066 |
css=CUSTOM_CSS,
|
| 1067 |
) as demo:
|
| 1068 |
-
|
| 1069 |
# ===== HEADER =====
|
| 1070 |
gr.HTML("""
|
| 1071 |
<div class="header-container">
|
|
@@ -1079,7 +1109,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1079 |
</div>
|
| 1080 |
</div>
|
| 1081 |
""")
|
| 1082 |
-
|
| 1083 |
# ===== API KEY INFO =====
|
| 1084 |
gr.HTML("""
|
| 1085 |
<div class="info-banner">
|
|
@@ -1096,20 +1126,20 @@ def create_demo() -> gr.Blocks:
|
|
| 1096 |
</div>
|
| 1097 |
</div>
|
| 1098 |
""")
|
| 1099 |
-
|
| 1100 |
# ===== MAIN TABS =====
|
| 1101 |
with gr.Tabs() as main_tabs:
|
| 1102 |
-
|
| 1103 |
# ==================== TAB 1: BIOMARKER ANALYSIS ====================
|
| 1104 |
with gr.Tab("🔬 Biomarker Analysis", id="biomarker-tab"):
|
| 1105 |
-
|
| 1106 |
# ===== MAIN CONTENT =====
|
| 1107 |
with gr.Row(equal_height=False):
|
| 1108 |
-
|
| 1109 |
# ----- LEFT PANEL: INPUT -----
|
| 1110 |
with gr.Column(scale=2, min_width=400):
|
| 1111 |
gr.HTML('<div class="section-title">📝 Enter Your Biomarkers</div>')
|
| 1112 |
-
|
| 1113 |
with gr.Group():
|
| 1114 |
input_text = gr.Textbox(
|
| 1115 |
label="",
|
|
@@ -1118,31 +1148,31 @@ def create_demo() -> gr.Blocks:
|
|
| 1118 |
max_lines=12,
|
| 1119 |
show_label=False,
|
| 1120 |
)
|
| 1121 |
-
|
| 1122 |
with gr.Row():
|
| 1123 |
analyze_btn = gr.Button(
|
| 1124 |
-
"🔬 Analyze Biomarkers",
|
| 1125 |
-
variant="primary",
|
| 1126 |
size="lg",
|
| 1127 |
scale=3,
|
| 1128 |
)
|
| 1129 |
clear_btn = gr.Button(
|
| 1130 |
-
"🗑️ Clear",
|
| 1131 |
variant="secondary",
|
| 1132 |
size="lg",
|
| 1133 |
scale=1,
|
| 1134 |
)
|
| 1135 |
-
|
| 1136 |
# Status display
|
| 1137 |
status_output = gr.Markdown(
|
| 1138 |
value="",
|
| 1139 |
elem_classes="status-box"
|
| 1140 |
)
|
| 1141 |
-
|
| 1142 |
# Quick Examples
|
| 1143 |
gr.HTML('<div class="section-title" style="margin-top: 24px;">⚡ Quick Examples</div>')
|
| 1144 |
gr.HTML('<p style="color: #64748b; font-size: 0.9em; margin-bottom: 12px;">Click any example to load it instantly</p>')
|
| 1145 |
-
|
| 1146 |
examples = gr.Examples(
|
| 1147 |
examples=[
|
| 1148 |
["Glucose: 185, HbA1c: 8.2, Cholesterol: 245, LDL: 165"],
|
|
@@ -1154,7 +1184,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1154 |
inputs=input_text,
|
| 1155 |
label="",
|
| 1156 |
)
|
| 1157 |
-
|
| 1158 |
# Supported Biomarkers
|
| 1159 |
with gr.Accordion("📊 Supported Biomarkers", open=False):
|
| 1160 |
gr.HTML("""
|
|
@@ -1185,11 +1215,11 @@ def create_demo() -> gr.Blocks:
|
|
| 1185 |
</div>
|
| 1186 |
</div>
|
| 1187 |
""")
|
| 1188 |
-
|
| 1189 |
# ----- RIGHT PANEL: RESULTS -----
|
| 1190 |
with gr.Column(scale=3, min_width=500):
|
| 1191 |
gr.HTML('<div class="section-title">📊 Analysis Results</div>')
|
| 1192 |
-
|
| 1193 |
with gr.Tabs() as result_tabs:
|
| 1194 |
with gr.Tab("📋 Summary", id="summary"):
|
| 1195 |
summary_output = gr.Markdown(
|
|
@@ -1202,7 +1232,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1202 |
""",
|
| 1203 |
elem_classes="summary-output"
|
| 1204 |
)
|
| 1205 |
-
|
| 1206 |
with gr.Tab("🔍 Detailed JSON", id="json"):
|
| 1207 |
details_output = gr.Code(
|
| 1208 |
label="",
|
|
@@ -1210,10 +1240,10 @@ def create_demo() -> gr.Blocks:
|
|
| 1210 |
lines=30,
|
| 1211 |
show_label=False,
|
| 1212 |
)
|
| 1213 |
-
|
| 1214 |
# ==================== TAB 2: MEDICAL Q&A ====================
|
| 1215 |
with gr.Tab("💬 Medical Q&A", id="qa-tab"):
|
| 1216 |
-
|
| 1217 |
gr.HTML("""
|
| 1218 |
<div style="margin-bottom: 20px;">
|
| 1219 |
<h3 style="color: #1e3a5f; margin: 0 0 8px 0;">💬 Medical Q&A Assistant</h3>
|
|
@@ -1222,7 +1252,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1222 |
</p>
|
| 1223 |
</div>
|
| 1224 |
""")
|
| 1225 |
-
|
| 1226 |
with gr.Row(equal_height=False):
|
| 1227 |
with gr.Column(scale=1):
|
| 1228 |
qa_context = gr.Textbox(
|
|
@@ -1231,6 +1261,11 @@ def create_demo() -> gr.Blocks:
|
|
| 1231 |
lines=3,
|
| 1232 |
max_lines=6,
|
| 1233 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1234 |
qa_question = gr.Textbox(
|
| 1235 |
label="Your Question",
|
| 1236 |
placeholder="Ask any medical question...\n• What do my elevated glucose levels indicate?\n• Should I be concerned about my HbA1c of 7.5%?\n• What lifestyle changes help with prediabetes?",
|
|
@@ -1246,11 +1281,11 @@ def create_demo() -> gr.Blocks:
|
|
| 1246 |
)
|
| 1247 |
qa_clear_btn = gr.Button(
|
| 1248 |
"🗑️ Clear",
|
| 1249 |
-
variant="secondary",
|
| 1250 |
size="lg",
|
| 1251 |
scale=1,
|
| 1252 |
)
|
| 1253 |
-
|
| 1254 |
# Quick question examples
|
| 1255 |
gr.HTML('<h4 style="margin-top: 16px; color: #1e3a5f;">Example Questions</h4>')
|
| 1256 |
qa_examples = gr.Examples(
|
|
@@ -1263,42 +1298,54 @@ def create_demo() -> gr.Blocks:
|
|
| 1263 |
inputs=[qa_question, qa_context],
|
| 1264 |
label="",
|
| 1265 |
)
|
| 1266 |
-
|
| 1267 |
with gr.Column(scale=2):
|
| 1268 |
gr.HTML('<h4 style="color: #1e3a5f; margin-bottom: 12px;">📝 Answer</h4>')
|
| 1269 |
-
qa_answer = gr.
|
| 1270 |
-
|
| 1271 |
-
|
| 1272 |
-
<div style="font-size: 3em; margin-bottom: 12px;">💬</div>
|
| 1273 |
-
<h3 style="color: #64748b; font-weight: 500;">Ask a Medical Question</h3>
|
| 1274 |
-
<p>Enter your question on the left and click <strong>Ask Question</strong> to get evidence-based answers.</p>
|
| 1275 |
-
</div>
|
| 1276 |
-
""",
|
| 1277 |
elem_classes="qa-output"
|
| 1278 |
)
|
| 1279 |
-
|
| 1280 |
# Q&A Event Handlers
|
| 1281 |
qa_submit_btn.click(
|
| 1282 |
fn=streaming_answer,
|
| 1283 |
-
inputs=[qa_question, qa_context],
|
| 1284 |
outputs=qa_answer,
|
| 1285 |
show_progress="minimal",
|
|
|
|
|
|
|
|
|
|
| 1286 |
)
|
| 1287 |
-
|
| 1288 |
qa_clear_btn.click(
|
| 1289 |
-
fn=lambda: (
|
| 1290 |
-
|
| 1291 |
-
<div style="font-size: 3em; margin-bottom: 12px;">💬</div>
|
| 1292 |
-
<h3 style="color: #64748b; font-weight: 500;">Ask a Medical Question</h3>
|
| 1293 |
-
<p>Enter your question on the left and click <strong>Ask Question</strong> to get evidence-based answers.</p>
|
| 1294 |
-
</div>
|
| 1295 |
-
"""),
|
| 1296 |
-
outputs=[qa_question, qa_context, qa_answer],
|
| 1297 |
)
|
| 1298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1299 |
# ===== HOW IT WORKS =====
|
| 1300 |
gr.HTML('<div class="section-title" style="margin-top: 32px;">🤖 How It Works</div>')
|
| 1301 |
-
|
| 1302 |
gr.HTML("""
|
| 1303 |
<div class="agent-grid">
|
| 1304 |
<div class="agent-card">
|
|
@@ -1327,7 +1374,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1327 |
</div>
|
| 1328 |
</div>
|
| 1329 |
""")
|
| 1330 |
-
|
| 1331 |
# ===== DISCLAIMER =====
|
| 1332 |
gr.HTML("""
|
| 1333 |
<div class="disclaimer">
|
|
@@ -1337,7 +1384,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1337 |
clinical guidelines and may not account for your specific medical history.
|
| 1338 |
</div>
|
| 1339 |
""")
|
| 1340 |
-
|
| 1341 |
# ===== FOOTER =====
|
| 1342 |
gr.HTML("""
|
| 1343 |
<div style="text-align: center; padding: 24px; color: #94a3b8; font-size: 0.85em; margin-top: 24px;">
|
|
@@ -1352,7 +1399,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1352 |
</p>
|
| 1353 |
</div>
|
| 1354 |
""")
|
| 1355 |
-
|
| 1356 |
# ===== EVENT HANDLERS =====
|
| 1357 |
analyze_btn.click(
|
| 1358 |
fn=analyze_biomarkers,
|
|
@@ -1360,7 +1407,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1360 |
outputs=[summary_output, details_output, status_output],
|
| 1361 |
show_progress="full",
|
| 1362 |
)
|
| 1363 |
-
|
| 1364 |
clear_btn.click(
|
| 1365 |
fn=lambda: ("", """
|
| 1366 |
<div style="text-align: center; padding: 60px 20px; color: #94a3b8;">
|
|
@@ -1371,7 +1418,7 @@ def create_demo() -> gr.Blocks:
|
|
| 1371 |
""", "", ""),
|
| 1372 |
outputs=[input_text, summary_output, details_output, status_output],
|
| 1373 |
)
|
| 1374 |
-
|
| 1375 |
return demo
|
| 1376 |
|
| 1377 |
|
|
@@ -1381,9 +1428,9 @@ def create_demo() -> gr.Blocks:
|
|
| 1381 |
|
| 1382 |
if __name__ == "__main__":
|
| 1383 |
logger.info("Starting MediGuard AI Gradio App...")
|
| 1384 |
-
|
| 1385 |
demo = create_demo()
|
| 1386 |
-
|
| 1387 |
# Launch with HF Spaces compatible settings
|
| 1388 |
demo.launch(
|
| 1389 |
server_name="0.0.0.0",
|
|
|
|
| 37 |
import time
|
| 38 |
import traceback
|
| 39 |
from pathlib import Path
|
| 40 |
+
from typing import Any
|
| 41 |
|
| 42 |
# Ensure project root is in path
|
| 43 |
_project_root = str(Path(__file__).parent.parent)
|
|
|
|
| 114 |
"""
|
| 115 |
groq_key, google_key = get_api_keys()
|
| 116 |
provider = None
|
| 117 |
+
|
| 118 |
if groq_key:
|
| 119 |
os.environ["LLM_PROVIDER"] = "groq"
|
| 120 |
os.environ["GROQ_API_KEY"] = groq_key
|
|
|
|
| 127 |
os.environ["GEMINI_MODEL"] = get_gemini_model()
|
| 128 |
provider = "gemini"
|
| 129 |
logger.info(f"Configured Gemini provider with model: {get_gemini_model()}")
|
| 130 |
+
|
| 131 |
# Set up embedding provider
|
| 132 |
embedding_provider = get_embedding_provider()
|
| 133 |
os.environ["EMBEDDING_PROVIDER"] = embedding_provider
|
| 134 |
+
|
| 135 |
# If Jina is configured, set the API key
|
| 136 |
jina_key = get_jina_api_key()
|
| 137 |
if jina_key:
|
| 138 |
os.environ["JINA_API_KEY"] = jina_key
|
| 139 |
os.environ["EMBEDDING__JINA_API_KEY"] = jina_key
|
| 140 |
logger.info("Jina embeddings configured")
|
| 141 |
+
|
| 142 |
# Set up Langfuse if enabled
|
| 143 |
if is_langfuse_enabled():
|
| 144 |
os.environ["LANGFUSE__ENABLED"] = "true"
|
|
|
|
| 147 |
if val:
|
| 148 |
os.environ[var] = val
|
| 149 |
logger.info("Langfuse observability enabled")
|
| 150 |
+
|
| 151 |
return provider
|
| 152 |
|
| 153 |
|
|
|
|
| 192 |
def get_guild():
|
| 193 |
"""Lazy initialization of the Clinical Insight Guild."""
|
| 194 |
global _guild, _guild_error, _guild_provider
|
| 195 |
+
|
| 196 |
# Check if we need to reinitialize (provider changed)
|
| 197 |
current_provider = os.getenv("LLM_PROVIDER")
|
| 198 |
if _guild_provider and _guild_provider != current_provider:
|
| 199 |
logger.info(f"Provider changed from {_guild_provider} to {current_provider}, reinitializing...")
|
| 200 |
reset_guild()
|
| 201 |
+
|
| 202 |
if _guild is not None:
|
| 203 |
return _guild
|
| 204 |
+
|
| 205 |
if _guild_error is not None:
|
| 206 |
# Don't cache errors forever - allow retry
|
| 207 |
logger.warning("Previous initialization failed, retrying...")
|
| 208 |
_guild_error = None
|
| 209 |
+
|
| 210 |
try:
|
| 211 |
logger.info("Initializing Clinical Insight Guild...")
|
| 212 |
logger.info(f" LLM_PROVIDER: {os.getenv('LLM_PROVIDER', 'not set')}")
|
|
|
|
| 214 |
logger.info(f" GOOGLE_API_KEY: {'✓ set' if os.getenv('GOOGLE_API_KEY') else '✗ not set'}")
|
| 215 |
logger.info(f" EMBEDDING_PROVIDER: {os.getenv('EMBEDDING_PROVIDER', 'huggingface')}")
|
| 216 |
logger.info(f" JINA_API_KEY: {'✓ set' if os.getenv('JINA_API_KEY') else '✗ not set'}")
|
| 217 |
+
|
| 218 |
start = time.time()
|
| 219 |
+
|
| 220 |
from src.workflow import create_guild
|
| 221 |
_guild = create_guild()
|
| 222 |
_guild_provider = current_provider
|
| 223 |
+
|
| 224 |
elapsed = time.time() - start
|
| 225 |
logger.info(f"Guild initialized in {elapsed:.1f}s")
|
| 226 |
return _guild
|
| 227 |
+
|
| 228 |
except Exception as exc:
|
| 229 |
logger.error(f"Failed to initialize guild: {exc}")
|
| 230 |
_guild_error = exc
|
|
|
|
| 237 |
|
| 238 |
# Import shared parsing and prediction logic
|
| 239 |
from src.shared_utils import (
|
|
|
|
| 240 |
get_primary_prediction,
|
| 241 |
+
parse_biomarkers,
|
|
|
|
|
|
|
| 242 |
)
|
| 243 |
|
| 244 |
|
|
|
|
| 264 |
<p style="margin: 8px 0 0 0; color: #64748b;">Please enter biomarkers to analyze.</p>
|
| 265 |
</div>
|
| 266 |
"""
|
| 267 |
+
|
| 268 |
# Check API key dynamically (HF injects secrets after startup)
|
| 269 |
groq_key, google_key = get_api_keys()
|
| 270 |
+
|
| 271 |
if not groq_key and not google_key:
|
| 272 |
return "", "", """
|
| 273 |
<div style="background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%); border: 1px solid #ef4444; border-radius: 10px; padding: 16px;">
|
|
|
|
| 294 |
</details>
|
| 295 |
</div>
|
| 296 |
"""
|
| 297 |
+
|
| 298 |
# Setup provider based on available key
|
| 299 |
provider = setup_llm_provider()
|
| 300 |
logger.info(f"Using LLM provider: {provider}")
|
| 301 |
+
|
| 302 |
try:
|
| 303 |
progress(0.1, desc="📝 Parsing biomarkers...")
|
| 304 |
biomarkers = parse_biomarkers(input_text)
|
| 305 |
+
|
| 306 |
if not biomarkers:
|
| 307 |
return "", "", """
|
| 308 |
<div style="background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); border: 1px solid #fbbf24; border-radius: 10px; padding: 16px;">
|
|
|
|
| 314 |
</ul>
|
| 315 |
</div>
|
| 316 |
"""
|
| 317 |
+
|
| 318 |
progress(0.2, desc="🔧 Initializing AI agents...")
|
| 319 |
+
|
| 320 |
# Initialize guild
|
| 321 |
guild = get_guild()
|
| 322 |
+
|
| 323 |
# Prepare input
|
| 324 |
from src.state import PatientInput
|
| 325 |
+
|
| 326 |
# Auto-generate prediction based on common patterns
|
| 327 |
prediction = auto_predict(biomarkers)
|
| 328 |
+
|
| 329 |
patient_input = PatientInput(
|
| 330 |
biomarkers=biomarkers,
|
| 331 |
model_prediction=prediction,
|
| 332 |
patient_context={"patient_id": "HF_User", "source": "huggingface_spaces"}
|
| 333 |
)
|
| 334 |
+
|
| 335 |
progress(0.4, desc="🤖 Running Clinical Insight Guild...")
|
| 336 |
+
|
| 337 |
# Run analysis
|
| 338 |
start = time.time()
|
| 339 |
result = guild.run(patient_input)
|
| 340 |
elapsed = time.time() - start
|
| 341 |
+
|
| 342 |
progress(0.9, desc="✨ Formatting results...")
|
| 343 |
+
|
| 344 |
# Extract response
|
| 345 |
final_response = result.get("final_response", {})
|
| 346 |
+
|
| 347 |
# Format summary
|
| 348 |
summary = format_summary(final_response, elapsed)
|
| 349 |
+
|
| 350 |
# Format details
|
| 351 |
details = json.dumps(final_response, indent=2, default=str)
|
| 352 |
+
|
| 353 |
status = f"""
|
| 354 |
<div style="background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%); border: 1px solid #10b981; border-radius: 10px; padding: 12px; display: flex; align-items: center; gap: 10px;">
|
| 355 |
<span style="font-size: 1.5em;">✅</span>
|
|
|
|
| 359 |
</div>
|
| 360 |
</div>
|
| 361 |
"""
|
| 362 |
+
|
| 363 |
return summary, details, status
|
| 364 |
+
|
| 365 |
except Exception as exc:
|
| 366 |
logger.error(f"Analysis error: {exc}", exc_info=True)
|
| 367 |
error_msg = f"""
|
|
|
|
| 381 |
"""Format the analysis response as clean markdown with black text."""
|
| 382 |
if not response:
|
| 383 |
return "❌ **No analysis results available.**"
|
| 384 |
+
|
| 385 |
parts = []
|
| 386 |
+
|
| 387 |
# Header with primary finding and confidence
|
| 388 |
primary = response.get("primary_finding", "Analysis Complete")
|
| 389 |
confidence = response.get("confidence", {})
|
| 390 |
conf_score = confidence.get("overall_score", 0) if isinstance(confidence, dict) else 0
|
| 391 |
+
|
| 392 |
# Determine severity
|
| 393 |
severity = response.get("severity", "low")
|
| 394 |
severity_config = {
|
|
|
|
| 398 |
"low": ("🟢", "#16a34a", "#f0fdf4")
|
| 399 |
}
|
| 400 |
emoji, color, bg_color = severity_config.get(severity, severity_config["low"])
|
| 401 |
+
|
| 402 |
# Build confidence display
|
| 403 |
conf_badge = ""
|
| 404 |
if conf_score:
|
| 405 |
conf_pct = int(conf_score * 100)
|
| 406 |
conf_color = "#16a34a" if conf_pct >= 80 else "#ca8a04" if conf_pct >= 60 else "#dc2626"
|
| 407 |
conf_badge = f'<span style="background: {conf_color}; color: white; padding: 4px 12px; border-radius: 20px; font-size: 0.85em; margin-left: 12px;">{conf_pct}% confidence</span>'
|
| 408 |
+
|
| 409 |
parts.append(f"""
|
| 410 |
<div style="background: linear-gradient(135deg, {bg_color} 0%, white 100%); border-left: 4px solid {color}; border-radius: 12px; padding: 20px; margin-bottom: 20px;">
|
| 411 |
<div style="display: flex; align-items: center; flex-wrap: wrap;">
|
|
|
|
| 414 |
{conf_badge}
|
| 415 |
</div>
|
| 416 |
</div>""")
|
| 417 |
+
|
| 418 |
# Critical Alerts
|
| 419 |
alerts = response.get("safety_alerts", [])
|
| 420 |
if alerts:
|
|
|
|
| 424 |
alert_items += f'<li><strong>{alert.get("alert_type", "Alert")}:</strong> {alert.get("message", "")}</li>'
|
| 425 |
else:
|
| 426 |
alert_items += f'<li>{alert}</li>'
|
| 427 |
+
|
| 428 |
parts.append(f"""
|
| 429 |
<div style="background: linear-gradient(135deg, #fef2f2 0%, #fee2e2 100%); border: 1px solid #fecaca; border-radius: 12px; padding: 16px; margin-bottom: 16px;">
|
| 430 |
<h4 style="margin: 0 0 12px 0; color: #dc2626; display: flex; align-items: center; gap: 8px;">
|
|
|
|
| 433 |
<ul style="margin: 0; padding-left: 20px; color: #991b1b;">{alert_items}</ul>
|
| 434 |
</div>
|
| 435 |
""")
|
| 436 |
+
|
| 437 |
# Key Findings
|
| 438 |
findings = response.get("key_findings", [])
|
| 439 |
if findings:
|
|
|
|
| 444 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">{finding_items}</ul>
|
| 445 |
</div>
|
| 446 |
""")
|
| 447 |
+
|
| 448 |
# Biomarker Flags - as a visual grid
|
| 449 |
flags = response.get("biomarker_flags", [])
|
| 450 |
if flags and len(flags) > 0:
|
|
|
|
| 457 |
continue
|
| 458 |
status = flag.get("status", "normal").lower()
|
| 459 |
value = flag.get("value", flag.get("result", "N/A"))
|
| 460 |
+
|
| 461 |
status_styles = {
|
| 462 |
"critical": ("🔴", "#dc2626", "#fef2f2"),
|
| 463 |
"high": ("🔴", "#dc2626", "#fef2f2"),
|
|
|
|
| 466 |
"normal": ("🟢", "#16a34a", "#f0fdf4")
|
| 467 |
}
|
| 468 |
s_emoji, s_color, s_bg = status_styles.get(status, status_styles["normal"])
|
| 469 |
+
|
| 470 |
flag_cards += f"""
|
| 471 |
<div style="background: {s_bg}; border: 1px solid {s_color}33; border-radius: 8px; padding: 12px; text-align: center;">
|
| 472 |
<div style="font-size: 1.2em;">{s_emoji}</div>
|
|
|
|
| 475 |
<div style="font-size: 0.75em; color: #64748b; text-transform: capitalize;">{status}</div>
|
| 476 |
</div>
|
| 477 |
"""
|
| 478 |
+
|
| 479 |
if flag_cards: # Only show section if we have cards
|
| 480 |
parts.append(f"""
|
| 481 |
<div style="margin-bottom: 16px;">
|
|
|
|
| 485 |
</div>
|
| 486 |
</div>
|
| 487 |
""")
|
| 488 |
+
|
| 489 |
# Recommendations - organized sections
|
| 490 |
recs = response.get("recommendations", {})
|
| 491 |
rec_sections = ""
|
| 492 |
+
|
| 493 |
immediate = recs.get("immediate_actions", []) if isinstance(recs, dict) else []
|
| 494 |
if immediate and len(immediate) > 0:
|
| 495 |
items = "".join([f'<li style="margin-bottom: 6px;">{str(a).strip()}</li>' for a in immediate[:3]])
|
|
|
|
| 499 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
|
| 500 |
</div>
|
| 501 |
"""
|
| 502 |
+
|
| 503 |
lifestyle = recs.get("lifestyle_modifications", []) if isinstance(recs, dict) else []
|
| 504 |
if lifestyle and len(lifestyle) > 0:
|
| 505 |
items = "".join([f'<li style="margin-bottom: 6px;">{str(m).strip()}</li>' for m in lifestyle[:3]])
|
|
|
|
| 509 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
|
| 510 |
</div>
|
| 511 |
"""
|
| 512 |
+
|
| 513 |
followup = recs.get("follow_up", []) if isinstance(recs, dict) else []
|
| 514 |
if followup and len(followup) > 0:
|
| 515 |
items = "".join([f'<li style="margin-bottom: 6px;">{str(f).strip()}</li>' for f in followup[:3]])
|
|
|
|
| 519 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
|
| 520 |
</div>
|
| 521 |
"""
|
| 522 |
+
|
| 523 |
# Add default recommendations if none provided
|
| 524 |
if not rec_sections:
|
| 525 |
+
rec_sections = """
|
| 526 |
<div style="margin-bottom: 12px;">
|
| 527 |
<h5 style="margin: 0 0 8px 0; color: #2563eb;">📋 General Recommendations</h5>
|
| 528 |
<ul style="margin: 0; padding-left: 20px; color: #475569;">
|
|
|
|
| 532 |
</ul>
|
| 533 |
</div>
|
| 534 |
"""
|
| 535 |
+
|
| 536 |
if rec_sections:
|
| 537 |
parts.append(f"""
|
| 538 |
<div style="background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%); border-radius: 12px; padding: 16px; margin-bottom: 16px;">
|
|
|
|
| 540 |
{rec_sections}
|
| 541 |
</div>
|
| 542 |
""")
|
| 543 |
+
|
| 544 |
# Disease Explanation
|
| 545 |
explanation = response.get("disease_explanation", {})
|
| 546 |
if explanation and isinstance(explanation, dict):
|
|
|
|
| 552 |
<p style="margin: 0; color: #475569; line-height: 1.6;">{pathophys[:600]}{'...' if len(pathophys) > 600 else ''}</p>
|
| 553 |
</div>
|
| 554 |
""")
|
| 555 |
+
|
| 556 |
# Conversational Summary
|
| 557 |
conv_summary = response.get("conversational_summary", "")
|
| 558 |
if conv_summary:
|
|
|
|
| 562 |
<p style="margin: 0; color: #475569; line-height: 1.6;">{conv_summary[:1000]}</p>
|
| 563 |
</div>
|
| 564 |
""")
|
| 565 |
+
|
| 566 |
# Footer
|
| 567 |
parts.append(f"""
|
| 568 |
<div style="border-top: 1px solid #e2e8f0; padding-top: 16px; margin-top: 8px; text-align: center;">
|
|
|
|
| 574 |
</p>
|
| 575 |
</div>
|
| 576 |
""")
|
| 577 |
+
|
| 578 |
return "\n".join(parts)
|
| 579 |
|
| 580 |
|
|
|
|
| 603 |
_rag_service_error = None
|
| 604 |
|
| 605 |
try:
|
| 606 |
+
from src.llm_config import get_synthesizer
|
| 607 |
from src.services.agents.agentic_rag import AgenticRAGService
|
| 608 |
from src.services.agents.context import AgenticContext
|
| 609 |
from src.services.retrieval.factory import make_retriever
|
|
|
|
| 610 |
|
| 611 |
llm = get_synthesizer()
|
| 612 |
retriever = make_retriever() # auto-detects FAISS
|
|
|
|
| 634 |
|
| 635 |
def _fallback_qa(question: str, context_text: str = "") -> str:
|
| 636 |
"""Direct retriever+LLM fallback when agentic pipeline is unavailable."""
|
|
|
|
| 637 |
from src.llm_config import get_synthesizer
|
| 638 |
+
from src.services.retrieval.factory import make_retriever
|
| 639 |
|
| 640 |
retriever = make_retriever()
|
| 641 |
search_query = f"{context_text} {question}" if context_text.strip() else question
|
|
|
|
| 724 |
|
| 725 |
except Exception as exc:
|
| 726 |
logger.exception(f"Q&A error: {exc}")
|
| 727 |
+
error_msg = f"❌ Error: {exc!s}"
|
| 728 |
history = (chat_history or []) + [(question, error_msg)]
|
| 729 |
return error_msg, history
|
| 730 |
|
| 731 |
|
| 732 |
+
def streaming_answer(question: str, context: str, history: list, model: str):
|
| 733 |
"""Stream answer using the full agentic RAG pipeline.
|
| 734 |
Falls back to direct retriever+LLM if the pipeline is unavailable.
|
| 735 |
"""
|
| 736 |
+
history = history or []
|
| 737 |
if not question.strip():
|
| 738 |
+
yield history
|
| 739 |
return
|
| 740 |
|
| 741 |
+
history.append((question, ""))
|
| 742 |
+
|
| 743 |
if not groq_key and not google_key:
|
| 744 |
+
history[-1] = (question, "❌ Please add your GROQ_API_KEY or GOOGLE_API_KEY in Space Settings → Secrets.")
|
| 745 |
+
yield history
|
| 746 |
return
|
| 747 |
|
| 748 |
+
# Update provider if model changed (simplified handling for UI demo)
|
| 749 |
+
if "gemini" in model.lower():
|
| 750 |
+
os.environ["LLM_PROVIDER"] = "gemini"
|
| 751 |
+
else:
|
| 752 |
+
os.environ["LLM_PROVIDER"] = "groq"
|
| 753 |
+
|
| 754 |
setup_llm_provider()
|
| 755 |
|
| 756 |
try:
|
| 757 |
+
history[-1] = (question, "🛡️ Checking medical domain relevance...\n\n")
|
| 758 |
+
yield history
|
| 759 |
|
| 760 |
start_time = time.time()
|
| 761 |
|
| 762 |
rag_service = _get_rag_service()
|
| 763 |
if rag_service is not None:
|
| 764 |
+
history[-1] = (question, "🛡️ Checking medical domain relevance...\n🔍 Retrieving medical documents...\n\n")
|
| 765 |
+
yield history
|
| 766 |
result = rag_service.ask(query=question, patient_context=context)
|
| 767 |
answer = result.get("final_answer", "")
|
| 768 |
guardrail = result.get("guardrail_score")
|
| 769 |
docs_relevant = len(result.get("relevant_documents", []))
|
| 770 |
docs_retrieved = len(result.get("retrieved_documents", []))
|
| 771 |
else:
|
| 772 |
+
history[-1] = (question, "🔍 Searching medical knowledge base...\n📚 Retrieving relevant documents...\n\n")
|
| 773 |
+
yield history
|
| 774 |
answer = _fallback_qa(question, context)
|
| 775 |
guardrail = None
|
| 776 |
docs_relevant = 0
|
|
|
|
| 779 |
if not answer:
|
| 780 |
answer = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
|
| 781 |
|
| 782 |
+
history[-1] = (question, "🛡️ Guardrail ✓\n🔍 Retrieved ✓\n📊 Graded ✓\n💭 Generating response...\n\n")
|
| 783 |
+
yield history
|
| 784 |
|
| 785 |
elapsed = time.time() - start_time
|
| 786 |
|
|
|
|
| 789 |
accumulated = ""
|
| 790 |
for i, word in enumerate(words):
|
| 791 |
accumulated += word + " "
|
| 792 |
+
if i % 10 == 0:
|
| 793 |
+
history[-1] = (question, accumulated)
|
| 794 |
+
yield history
|
| 795 |
+
time.sleep(0.01)
|
| 796 |
|
| 797 |
# Final response with metadata
|
| 798 |
meta_parts = [f"⏱️ {elapsed:.1f}s"]
|
|
|
|
| 803 |
meta_parts.append("🤖 Agentic RAG" if rag_service else "🤖 RAG")
|
| 804 |
meta_line = " | ".join(meta_parts)
|
| 805 |
|
| 806 |
+
final_msg = f"{answer}\n\n---\n*{meta_line}*\n"
|
| 807 |
+
history[-1] = (question, final_msg)
|
| 808 |
+
yield history
|
|
|
|
|
|
|
| 809 |
|
| 810 |
except Exception as exc:
|
| 811 |
logger.exception(f"Streaming Q&A error: {exc}")
|
| 812 |
+
history[-1] = (question, f"❌ Error: {exc!s}")
|
| 813 |
+
yield history
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
def hf_search(query: str, mode: str):
|
| 817 |
+
"""Direct fast-retrieval for the HF Space Knowledge tab."""
|
| 818 |
+
if not query.strip():
|
| 819 |
+
return "Please enter a query."
|
| 820 |
+
try:
|
| 821 |
+
from src.services.retrieval.factory import make_retriever
|
| 822 |
+
retriever = make_retriever()
|
| 823 |
+
docs = retriever.retrieve(query, top_k=5)
|
| 824 |
+
if not docs:
|
| 825 |
+
return "No results found."
|
| 826 |
+
parts = []
|
| 827 |
+
for i, doc in enumerate(docs, 1):
|
| 828 |
+
title = doc.metadata.get("title", doc.metadata.get("source_file", "Untitled"))
|
| 829 |
+
score = doc.score if hasattr(doc, 'score') else 0.0
|
| 830 |
+
parts.append(f"**[{i}] {title}** (score: {score:.3f})\n{doc.content}\n")
|
| 831 |
+
return "\n---\n".join(parts)
|
| 832 |
+
except Exception as exc:
|
| 833 |
+
return f"Error: {exc}"
|
| 834 |
|
| 835 |
|
| 836 |
# ---------------------------------------------------------------------------
|
|
|
|
| 1069 |
|
| 1070 |
def create_demo() -> gr.Blocks:
|
| 1071 |
"""Create the Gradio Blocks interface with modern medical UI."""
|
| 1072 |
+
|
| 1073 |
with gr.Blocks(
|
| 1074 |
title="Agentic RagBot - Medical Biomarker Analysis",
|
| 1075 |
theme=gr.themes.Soft(
|
|
|
|
| 1095 |
),
|
| 1096 |
css=CUSTOM_CSS,
|
| 1097 |
) as demo:
|
| 1098 |
+
|
| 1099 |
# ===== HEADER =====
|
| 1100 |
gr.HTML("""
|
| 1101 |
<div class="header-container">
|
|
|
|
| 1109 |
</div>
|
| 1110 |
</div>
|
| 1111 |
""")
|
| 1112 |
+
|
| 1113 |
# ===== API KEY INFO =====
|
| 1114 |
gr.HTML("""
|
| 1115 |
<div class="info-banner">
|
|
|
|
| 1126 |
</div>
|
| 1127 |
</div>
|
| 1128 |
""")
|
| 1129 |
+
|
| 1130 |
# ===== MAIN TABS =====
|
| 1131 |
with gr.Tabs() as main_tabs:
|
| 1132 |
+
|
| 1133 |
# ==================== TAB 1: BIOMARKER ANALYSIS ====================
|
| 1134 |
with gr.Tab("🔬 Biomarker Analysis", id="biomarker-tab"):
|
| 1135 |
+
|
| 1136 |
# ===== MAIN CONTENT =====
|
| 1137 |
with gr.Row(equal_height=False):
|
| 1138 |
+
|
| 1139 |
# ----- LEFT PANEL: INPUT -----
|
| 1140 |
with gr.Column(scale=2, min_width=400):
|
| 1141 |
gr.HTML('<div class="section-title">📝 Enter Your Biomarkers</div>')
|
| 1142 |
+
|
| 1143 |
with gr.Group():
|
| 1144 |
input_text = gr.Textbox(
|
| 1145 |
label="",
|
|
|
|
| 1148 |
max_lines=12,
|
| 1149 |
show_label=False,
|
| 1150 |
)
|
| 1151 |
+
|
| 1152 |
with gr.Row():
|
| 1153 |
analyze_btn = gr.Button(
|
| 1154 |
+
"🔬 Analyze Biomarkers",
|
| 1155 |
+
variant="primary",
|
| 1156 |
size="lg",
|
| 1157 |
scale=3,
|
| 1158 |
)
|
| 1159 |
clear_btn = gr.Button(
|
| 1160 |
+
"🗑️ Clear",
|
| 1161 |
variant="secondary",
|
| 1162 |
size="lg",
|
| 1163 |
scale=1,
|
| 1164 |
)
|
| 1165 |
+
|
| 1166 |
# Status display
|
| 1167 |
status_output = gr.Markdown(
|
| 1168 |
value="",
|
| 1169 |
elem_classes="status-box"
|
| 1170 |
)
|
| 1171 |
+
|
| 1172 |
# Quick Examples
|
| 1173 |
gr.HTML('<div class="section-title" style="margin-top: 24px;">⚡ Quick Examples</div>')
|
| 1174 |
gr.HTML('<p style="color: #64748b; font-size: 0.9em; margin-bottom: 12px;">Click any example to load it instantly</p>')
|
| 1175 |
+
|
| 1176 |
examples = gr.Examples(
|
| 1177 |
examples=[
|
| 1178 |
["Glucose: 185, HbA1c: 8.2, Cholesterol: 245, LDL: 165"],
|
|
|
|
| 1184 |
inputs=input_text,
|
| 1185 |
label="",
|
| 1186 |
)
|
| 1187 |
+
|
| 1188 |
# Supported Biomarkers
|
| 1189 |
with gr.Accordion("📊 Supported Biomarkers", open=False):
|
| 1190 |
gr.HTML("""
|
|
|
|
| 1215 |
</div>
|
| 1216 |
</div>
|
| 1217 |
""")
|
| 1218 |
+
|
| 1219 |
# ----- RIGHT PANEL: RESULTS -----
|
| 1220 |
with gr.Column(scale=3, min_width=500):
|
| 1221 |
gr.HTML('<div class="section-title">📊 Analysis Results</div>')
|
| 1222 |
+
|
| 1223 |
with gr.Tabs() as result_tabs:
|
| 1224 |
with gr.Tab("📋 Summary", id="summary"):
|
| 1225 |
summary_output = gr.Markdown(
|
|
|
|
| 1232 |
""",
|
| 1233 |
elem_classes="summary-output"
|
| 1234 |
)
|
| 1235 |
+
|
| 1236 |
with gr.Tab("🔍 Detailed JSON", id="json"):
|
| 1237 |
details_output = gr.Code(
|
| 1238 |
label="",
|
|
|
|
| 1240 |
lines=30,
|
| 1241 |
show_label=False,
|
| 1242 |
)
|
| 1243 |
+
|
| 1244 |
# ==================== TAB 2: MEDICAL Q&A ====================
|
| 1245 |
with gr.Tab("💬 Medical Q&A", id="qa-tab"):
|
| 1246 |
+
|
| 1247 |
gr.HTML("""
|
| 1248 |
<div style="margin-bottom: 20px;">
|
| 1249 |
<h3 style="color: #1e3a5f; margin: 0 0 8px 0;">💬 Medical Q&A Assistant</h3>
|
|
|
|
| 1252 |
</p>
|
| 1253 |
</div>
|
| 1254 |
""")
|
| 1255 |
+
|
| 1256 |
with gr.Row(equal_height=False):
|
| 1257 |
with gr.Column(scale=1):
|
| 1258 |
qa_context = gr.Textbox(
|
|
|
|
| 1261 |
lines=3,
|
| 1262 |
max_lines=6,
|
| 1263 |
)
|
| 1264 |
+
qa_model = gr.Dropdown(
|
| 1265 |
+
choices=["llama-3.3-70b-versatile", "gemini-2.0-flash", "llama3.1:8b"],
|
| 1266 |
+
value="llama-3.3-70b-versatile",
|
| 1267 |
+
label="LLM Provider/Model"
|
| 1268 |
+
)
|
| 1269 |
qa_question = gr.Textbox(
|
| 1270 |
label="Your Question",
|
| 1271 |
placeholder="Ask any medical question...\n• What do my elevated glucose levels indicate?\n• Should I be concerned about my HbA1c of 7.5%?\n• What lifestyle changes help with prediabetes?",
|
|
|
|
| 1281 |
)
|
| 1282 |
qa_clear_btn = gr.Button(
|
| 1283 |
"🗑️ Clear",
|
| 1284 |
+
variant="secondary",
|
| 1285 |
size="lg",
|
| 1286 |
scale=1,
|
| 1287 |
)
|
| 1288 |
+
|
| 1289 |
# Quick question examples
|
| 1290 |
gr.HTML('<h4 style="margin-top: 16px; color: #1e3a5f;">Example Questions</h4>')
|
| 1291 |
qa_examples = gr.Examples(
|
|
|
|
| 1298 |
inputs=[qa_question, qa_context],
|
| 1299 |
label="",
|
| 1300 |
)
|
| 1301 |
+
|
| 1302 |
with gr.Column(scale=2):
|
| 1303 |
gr.HTML('<h4 style="color: #1e3a5f; margin-bottom: 12px;">📝 Answer</h4>')
|
| 1304 |
+
qa_answer = gr.Chatbot(
|
| 1305 |
+
label="Medical Q&A History",
|
| 1306 |
+
height=600,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1307 |
elem_classes="qa-output"
|
| 1308 |
)
|
| 1309 |
+
|
| 1310 |
# Q&A Event Handlers
|
| 1311 |
qa_submit_btn.click(
|
| 1312 |
fn=streaming_answer,
|
| 1313 |
+
inputs=[qa_question, qa_context, qa_answer, qa_model],
|
| 1314 |
outputs=qa_answer,
|
| 1315 |
show_progress="minimal",
|
| 1316 |
+
).then(
|
| 1317 |
+
fn=lambda: "",
|
| 1318 |
+
outputs=qa_question
|
| 1319 |
)
|
| 1320 |
+
|
| 1321 |
qa_clear_btn.click(
|
| 1322 |
+
fn=lambda: ([], ""),
|
| 1323 |
+
outputs=[qa_answer, qa_question],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1324 |
)
|
| 1325 |
+
|
| 1326 |
+
# ==================== TAB 3: SEARCH KNOWLEDGE BASE ====================
|
| 1327 |
+
with gr.Tab("🔍 Search Knowledge Base", id="search-tab"):
|
| 1328 |
+
with gr.Row():
|
| 1329 |
+
search_input = gr.Textbox(
|
| 1330 |
+
label="Search Query",
|
| 1331 |
+
placeholder="e.g., diabetes management guidelines",
|
| 1332 |
+
lines=2,
|
| 1333 |
+
scale=3
|
| 1334 |
+
)
|
| 1335 |
+
search_mode = gr.Radio(
|
| 1336 |
+
choices=["hybrid", "bm25", "vector"],
|
| 1337 |
+
value="hybrid",
|
| 1338 |
+
label="Search Strategy",
|
| 1339 |
+
scale=1
|
| 1340 |
+
)
|
| 1341 |
+
search_btn = gr.Button("Search", variant="primary")
|
| 1342 |
+
search_output = gr.Textbox(label="Results", lines=20, interactive=False)
|
| 1343 |
+
|
| 1344 |
+
search_btn.click(fn=hf_search, inputs=[search_input, search_mode], outputs=search_output)
|
| 1345 |
+
|
| 1346 |
# ===== HOW IT WORKS =====
|
| 1347 |
gr.HTML('<div class="section-title" style="margin-top: 32px;">🤖 How It Works</div>')
|
| 1348 |
+
|
| 1349 |
gr.HTML("""
|
| 1350 |
<div class="agent-grid">
|
| 1351 |
<div class="agent-card">
|
|
|
|
| 1374 |
</div>
|
| 1375 |
</div>
|
| 1376 |
""")
|
| 1377 |
+
|
| 1378 |
# ===== DISCLAIMER =====
|
| 1379 |
gr.HTML("""
|
| 1380 |
<div class="disclaimer">
|
|
|
|
| 1384 |
clinical guidelines and may not account for your specific medical history.
|
| 1385 |
</div>
|
| 1386 |
""")
|
| 1387 |
+
|
| 1388 |
# ===== FOOTER =====
|
| 1389 |
gr.HTML("""
|
| 1390 |
<div style="text-align: center; padding: 24px; color: #94a3b8; font-size: 0.85em; margin-top: 24px;">
|
|
|
|
| 1399 |
</p>
|
| 1400 |
</div>
|
| 1401 |
""")
|
| 1402 |
+
|
| 1403 |
# ===== EVENT HANDLERS =====
|
| 1404 |
analyze_btn.click(
|
| 1405 |
fn=analyze_biomarkers,
|
|
|
|
| 1407 |
outputs=[summary_output, details_output, status_output],
|
| 1408 |
show_progress="full",
|
| 1409 |
)
|
| 1410 |
+
|
| 1411 |
clear_btn.click(
|
| 1412 |
fn=lambda: ("", """
|
| 1413 |
<div style="text-align: center; padding: 60px 20px; color: #94a3b8;">
|
|
|
|
| 1418 |
""", "", ""),
|
| 1419 |
outputs=[input_text, summary_output, details_output, status_output],
|
| 1420 |
)
|
| 1421 |
+
|
| 1422 |
return demo
|
| 1423 |
|
| 1424 |
|
|
|
|
| 1428 |
|
| 1429 |
if __name__ == "__main__":
|
| 1430 |
logger.info("Starting MediGuard AI Gradio App...")
|
| 1431 |
+
|
| 1432 |
demo = create_demo()
|
| 1433 |
+
|
| 1434 |
# Launch with HF Spaces compatible settings
|
| 1435 |
demo.launch(
|
| 1436 |
server_name="0.0.0.0",
|
pytest.ini
CHANGED
|
@@ -2,3 +2,6 @@
|
|
| 2 |
filterwarnings =
|
| 3 |
ignore::langchain_core._api.deprecation.LangChainDeprecationWarning
|
| 4 |
ignore:.*class.*HuggingFaceEmbeddings.*was deprecated.*:DeprecationWarning
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
filterwarnings =
|
| 3 |
ignore::langchain_core._api.deprecation.LangChainDeprecationWarning
|
| 4 |
ignore:.*class.*HuggingFaceEmbeddings.*was deprecated.*:DeprecationWarning
|
| 5 |
+
|
| 6 |
+
markers =
|
| 7 |
+
integration: mark a test as an integration test.
|
requirements.txt
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
# MediGuard AI RAG-Helper - Dependencies
|
| 2 |
-
|
| 3 |
-
# Core Framework
|
| 4 |
-
langchain>=0.1.0
|
| 5 |
-
langgraph>=0.0.20
|
| 6 |
-
langchain-community>=0.0.13
|
| 7 |
-
langchain-core>=0.1.10
|
| 8 |
-
|
| 9 |
-
# LLM Providers (Cloud - FREE tiers available)
|
| 10 |
-
langchain-groq>=0.1.0 # Groq API (FREE tier, llama-3.3-70b)
|
| 11 |
-
langchain-google-genai>=1.0.0 # Google Gemini (FREE tier)
|
| 12 |
-
|
| 13 |
-
# Local LLM (optional, for offline use)
|
| 14 |
-
# ollama>=0.1.6
|
| 15 |
-
|
| 16 |
-
# Vector Store & Embeddings
|
| 17 |
-
faiss-cpu>=1.9.0
|
| 18 |
-
sentence-transformers>=2.2.2
|
| 19 |
-
|
| 20 |
-
# Document Processing
|
| 21 |
-
pypdf>=3.17.4
|
| 22 |
-
pydantic>=2.5.3
|
| 23 |
-
|
| 24 |
-
# Data Handling
|
| 25 |
-
pandas>=2.1.4
|
| 26 |
-
|
| 27 |
-
# Environment & Configuration
|
| 28 |
-
python-dotenv>=1.0.0
|
| 29 |
-
|
| 30 |
-
# Utilities
|
| 31 |
-
numpy>=1.26.2
|
| 32 |
-
matplotlib>=3.8.2
|
| 33 |
-
|
| 34 |
-
# Optional: improved readability scoring for evaluations
|
| 35 |
-
textstat>=0.7.3
|
| 36 |
-
|
| 37 |
-
# Optional: HuggingFace embedding provider
|
| 38 |
-
# langchain-huggingface>=0.0.1
|
| 39 |
-
|
| 40 |
-
# Optional: Ollama local LLM provider
|
| 41 |
-
# langchain-ollama>=0.0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/chat.py
CHANGED
|
@@ -4,9 +4,9 @@ Enables natural language conversation with the RAG system
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import json
|
| 7 |
-
import sys
|
| 8 |
-
import os
|
| 9 |
import logging
|
|
|
|
|
|
|
| 10 |
import warnings
|
| 11 |
|
| 12 |
# ── Silence HuggingFace / transformers noise BEFORE any ML library is loaded ──
|
|
@@ -21,9 +21,9 @@ logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
|
|
| 21 |
warnings.filterwarnings("ignore", message=".*class.*HuggingFaceEmbeddings.*was deprecated.*")
|
| 22 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 23 |
|
| 24 |
-
from pathlib import Path
|
| 25 |
-
from typing import Dict, Any, Tuple
|
| 26 |
from datetime import datetime
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# Set UTF-8 encoding for Windows console
|
| 29 |
if sys.platform == 'win32':
|
|
@@ -40,11 +40,11 @@ if sys.platform == 'win32':
|
|
| 40 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 41 |
|
| 42 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
| 43 |
from src.biomarker_normalization import normalize_biomarker_name
|
| 44 |
from src.llm_config import get_chat_model
|
| 45 |
-
from src.workflow import create_guild
|
| 46 |
from src.state import PatientInput
|
| 47 |
-
|
| 48 |
|
| 49 |
# ============================================================================
|
| 50 |
# BIOMARKER EXTRACTION PROMPT
|
|
@@ -82,7 +82,7 @@ If you cannot find any biomarkers, return {{"biomarkers": {{}}, "patient_context
|
|
| 82 |
# Component 1: Biomarker Extraction
|
| 83 |
# ============================================================================
|
| 84 |
|
| 85 |
-
def _parse_llm_json(content: str) ->
|
| 86 |
"""Parse JSON payload from LLM output with fallback recovery."""
|
| 87 |
text = content.strip()
|
| 88 |
|
|
@@ -101,7 +101,7 @@ def _parse_llm_json(content: str) -> Dict[str, Any]:
|
|
| 101 |
raise
|
| 102 |
|
| 103 |
|
| 104 |
-
def extract_biomarkers(user_message: str) ->
|
| 105 |
"""
|
| 106 |
Extract biomarker values from natural language using LLM.
|
| 107 |
|
|
@@ -111,17 +111,17 @@ def extract_biomarkers(user_message: str) -> Tuple[Dict[str, float], Dict[str, A
|
|
| 111 |
try:
|
| 112 |
llm = get_chat_model(temperature=0.0)
|
| 113 |
prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT)
|
| 114 |
-
|
| 115 |
chain = prompt | llm
|
| 116 |
response = chain.invoke({"user_message": user_message})
|
| 117 |
-
|
| 118 |
# Parse JSON from LLM response
|
| 119 |
content = response.content.strip()
|
| 120 |
-
|
| 121 |
extracted = _parse_llm_json(content)
|
| 122 |
biomarkers = extracted.get("biomarkers", {})
|
| 123 |
patient_context = extracted.get("patient_context", {})
|
| 124 |
-
|
| 125 |
# Normalize biomarker names
|
| 126 |
normalized = {}
|
| 127 |
for key, value in biomarkers.items():
|
|
@@ -131,12 +131,12 @@ def extract_biomarkers(user_message: str) -> Tuple[Dict[str, float], Dict[str, A
|
|
| 131 |
except (ValueError, TypeError) as e:
|
| 132 |
print(f"⚠️ Skipping invalid value for {key}: {value} (error: {e})")
|
| 133 |
continue
|
| 134 |
-
|
| 135 |
# Clean up patient context (remove null values)
|
| 136 |
patient_context = {k: v for k, v in patient_context.items() if v is not None}
|
| 137 |
-
|
| 138 |
return normalized, patient_context
|
| 139 |
-
|
| 140 |
except Exception as e:
|
| 141 |
print(f"⚠️ Extraction failed: {e}")
|
| 142 |
import traceback
|
|
@@ -148,7 +148,7 @@ def extract_biomarkers(user_message: str) -> Tuple[Dict[str, float], Dict[str, A
|
|
| 148 |
# Component 2: Disease Prediction
|
| 149 |
# ============================================================================
|
| 150 |
|
| 151 |
-
def predict_disease_simple(biomarkers:
|
| 152 |
"""
|
| 153 |
Simple rule-based disease prediction based on key biomarkers.
|
| 154 |
"""
|
|
@@ -159,15 +159,15 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 159 |
"Thrombocytopenia": 0.0,
|
| 160 |
"Thalassemia": 0.0
|
| 161 |
}
|
| 162 |
-
|
| 163 |
# Helper: check both abbreviated and normalized biomarker names
|
| 164 |
# Returns None when biomarker is not present (avoids false triggers)
|
| 165 |
def _get(name, *alt_names):
|
| 166 |
-
val = biomarkers.get(name
|
| 167 |
if val is not None:
|
| 168 |
return val
|
| 169 |
for alt in alt_names:
|
| 170 |
-
val = biomarkers.get(alt
|
| 171 |
if val is not None:
|
| 172 |
return val
|
| 173 |
return None
|
|
@@ -181,7 +181,7 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 181 |
scores["Diabetes"] += 0.2
|
| 182 |
if hba1c is not None and hba1c >= 6.5:
|
| 183 |
scores["Diabetes"] += 0.5
|
| 184 |
-
|
| 185 |
# Anemia indicators
|
| 186 |
hemoglobin = _get("Hemoglobin")
|
| 187 |
mcv = _get("Mean Corpuscular Volume", "MCV")
|
|
@@ -191,7 +191,7 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 191 |
scores["Anemia"] += 0.2
|
| 192 |
if mcv is not None and mcv < 80:
|
| 193 |
scores["Anemia"] += 0.2
|
| 194 |
-
|
| 195 |
# Heart disease indicators
|
| 196 |
cholesterol = _get("Cholesterol")
|
| 197 |
troponin = _get("Troponin")
|
|
@@ -202,32 +202,32 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 202 |
scores["Heart Disease"] += 0.6
|
| 203 |
if ldl is not None and ldl > 190:
|
| 204 |
scores["Heart Disease"] += 0.2
|
| 205 |
-
|
| 206 |
# Thrombocytopenia indicators
|
| 207 |
platelets = _get("Platelets")
|
| 208 |
if platelets is not None and platelets < 150000:
|
| 209 |
scores["Thrombocytopenia"] += 0.6
|
| 210 |
if platelets is not None and platelets < 50000:
|
| 211 |
scores["Thrombocytopenia"] += 0.3
|
| 212 |
-
|
| 213 |
# Thalassemia indicators (complex, simplified here)
|
| 214 |
if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
|
| 215 |
scores["Thalassemia"] += 0.4
|
| 216 |
-
|
| 217 |
# Find top prediction
|
| 218 |
top_disease = max(scores, key=scores.get)
|
| 219 |
confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation
|
| 220 |
-
|
| 221 |
if confidence == 0.0:
|
| 222 |
top_disease = "Undetermined"
|
| 223 |
-
|
| 224 |
# Normalize probabilities to sum to 1.0
|
| 225 |
total = sum(scores.values())
|
| 226 |
if total > 0:
|
| 227 |
probabilities = {k: v / total for k, v in scores.items()}
|
| 228 |
else:
|
| 229 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 230 |
-
|
| 231 |
return {
|
| 232 |
"disease": top_disease,
|
| 233 |
"confidence": confidence,
|
|
@@ -235,14 +235,14 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 235 |
}
|
| 236 |
|
| 237 |
|
| 238 |
-
def predict_disease_llm(biomarkers:
|
| 239 |
"""
|
| 240 |
Use LLM to predict most likely disease based on biomarker pattern.
|
| 241 |
Falls back to rule-based if LLM fails.
|
| 242 |
"""
|
| 243 |
try:
|
| 244 |
llm = get_chat_model(temperature=0.0)
|
| 245 |
-
|
| 246 |
prompt = f"""You are a medical AI assistant. Based on these biomarker values,
|
| 247 |
predict the most likely disease from: Diabetes, Anemia, Heart Disease, Thrombocytopenia, Thalassemia.
|
| 248 |
|
|
@@ -265,18 +265,18 @@ Return ONLY valid JSON (no other text):
|
|
| 265 |
}}
|
| 266 |
}}
|
| 267 |
"""
|
| 268 |
-
|
| 269 |
response = llm.invoke(prompt)
|
| 270 |
content = response.content.strip()
|
| 271 |
-
|
| 272 |
prediction = _parse_llm_json(content)
|
| 273 |
-
|
| 274 |
# Validate required fields
|
| 275 |
if "disease" in prediction and "confidence" in prediction and "probabilities" in prediction:
|
| 276 |
return prediction
|
| 277 |
else:
|
| 278 |
raise ValueError("Invalid prediction format")
|
| 279 |
-
|
| 280 |
except Exception as e:
|
| 281 |
print(f"⚠️ LLM prediction failed ({e}), using rule-based fallback")
|
| 282 |
import traceback
|
|
@@ -288,7 +288,7 @@ Return ONLY valid JSON (no other text):
|
|
| 288 |
# Component 3: Conversational Formatter
|
| 289 |
# ============================================================================
|
| 290 |
|
| 291 |
-
def _coerce_to_dict(obj) ->
|
| 292 |
"""Convert a Pydantic model or arbitrary object to a plain dict."""
|
| 293 |
if isinstance(obj, dict):
|
| 294 |
return obj
|
|
@@ -299,7 +299,7 @@ def _coerce_to_dict(obj) -> Dict:
|
|
| 299 |
return {}
|
| 300 |
|
| 301 |
|
| 302 |
-
def format_conversational(result:
|
| 303 |
"""
|
| 304 |
Format technical JSON output into conversational response.
|
| 305 |
"""
|
|
@@ -313,22 +313,22 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
|
|
| 313 |
confidence = result.get("confidence_assessment", {}) or {}
|
| 314 |
# Normalize: items may be Pydantic SafetyAlert objects or plain dicts
|
| 315 |
alerts = [_coerce_to_dict(a) for a in (result.get("safety_alerts") or [])]
|
| 316 |
-
|
| 317 |
disease = prediction.get("primary_disease", "Unknown")
|
| 318 |
conf_score = prediction.get("confidence", 0.0)
|
| 319 |
-
|
| 320 |
# Build conversational response
|
| 321 |
response = []
|
| 322 |
-
|
| 323 |
# 1. Greeting and main finding
|
| 324 |
response.append(f"Hi {user_name}! 👋\n")
|
| 325 |
-
response.append(
|
| 326 |
-
|
| 327 |
# 2. Primary diagnosis with confidence
|
| 328 |
emoji = "🔴" if conf_score >= 0.8 else "🟡" if conf_score >= 0.6 else "🟢"
|
| 329 |
response.append(f"{emoji} **Primary Finding:** {disease}")
|
| 330 |
response.append(f" Confidence: {conf_score:.0%}\n")
|
| 331 |
-
|
| 332 |
# 3. Critical safety alerts (if any)
|
| 333 |
critical_alerts = [a for a in alerts if a.get("severity") == "CRITICAL"]
|
| 334 |
if critical_alerts:
|
|
@@ -337,7 +337,7 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
|
|
| 337 |
response.append(f" • {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}")
|
| 338 |
response.append(f" → {alert.get('action', 'Consult healthcare provider')}")
|
| 339 |
response.append("")
|
| 340 |
-
|
| 341 |
# 4. Key drivers explanation
|
| 342 |
key_drivers = prediction.get("key_drivers", [])
|
| 343 |
if key_drivers:
|
|
@@ -351,7 +351,7 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
|
|
| 351 |
explanation = explanation[:147] + "..."
|
| 352 |
response.append(f" • **{biomarker}** ({value}): {explanation}")
|
| 353 |
response.append("")
|
| 354 |
-
|
| 355 |
# 5. What to do next (immediate actions)
|
| 356 |
immediate = recommendations.get("immediate_actions", [])
|
| 357 |
if immediate:
|
|
@@ -359,7 +359,7 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
|
|
| 359 |
for i, action in enumerate(immediate[:3], 1):
|
| 360 |
response.append(f" {i}. {action}")
|
| 361 |
response.append("")
|
| 362 |
-
|
| 363 |
# 6. Lifestyle recommendations
|
| 364 |
lifestyle = recommendations.get("lifestyle_changes", [])
|
| 365 |
if lifestyle:
|
|
@@ -367,11 +367,11 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
|
|
| 367 |
for i, change in enumerate(lifestyle[:3], 1):
|
| 368 |
response.append(f" {i}. {change}")
|
| 369 |
response.append("")
|
| 370 |
-
|
| 371 |
# 7. Disclaimer
|
| 372 |
response.append("ℹ️ **Important:** This is an AI-assisted analysis, NOT medical advice.")
|
| 373 |
response.append(" Please consult a healthcare professional for proper diagnosis and treatment.\n")
|
| 374 |
-
|
| 375 |
return "\n".join(response)
|
| 376 |
|
| 377 |
|
|
@@ -397,7 +397,7 @@ def run_example_case(guild):
|
|
| 397 |
"""Run example diabetes patient case"""
|
| 398 |
print("\n📋 Running Example: Type 2 Diabetes Patient")
|
| 399 |
print(" 52-year-old male with elevated glucose and HbA1c\n")
|
| 400 |
-
|
| 401 |
example_biomarkers = {
|
| 402 |
"Glucose": 185.0,
|
| 403 |
"HbA1c": 8.2,
|
|
@@ -411,7 +411,7 @@ def run_example_case(guild):
|
|
| 411 |
"Systolic Blood Pressure": 145,
|
| 412 |
"Diastolic Blood Pressure": 92
|
| 413 |
}
|
| 414 |
-
|
| 415 |
prediction = {
|
| 416 |
"disease": "Diabetes",
|
| 417 |
"confidence": 0.87,
|
|
@@ -423,16 +423,16 @@ def run_example_case(guild):
|
|
| 423 |
"Thalassemia": 0.01
|
| 424 |
}
|
| 425 |
}
|
| 426 |
-
|
| 427 |
patient_input = PatientInput(
|
| 428 |
biomarkers=example_biomarkers,
|
| 429 |
model_prediction=prediction,
|
| 430 |
patient_context={"age": 52, "gender": "male", "bmi": 31.2}
|
| 431 |
)
|
| 432 |
-
|
| 433 |
print("🔄 Running analysis...\n")
|
| 434 |
result = guild.run(patient_input)
|
| 435 |
-
|
| 436 |
response = format_conversational(result.get("final_response", result), "there")
|
| 437 |
print("\n" + "="*70)
|
| 438 |
print("🤖 RAG-BOT:")
|
|
@@ -441,7 +441,7 @@ def run_example_case(guild):
|
|
| 441 |
print("="*70 + "\n")
|
| 442 |
|
| 443 |
|
| 444 |
-
def save_report(result:
|
| 445 |
"""Save detailed JSON report to file"""
|
| 446 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 447 |
|
|
@@ -505,7 +505,7 @@ def chat_interface():
|
|
| 505 |
print(" 3. Type 'help' for biomarker list")
|
| 506 |
print(" 4. Type 'quit' to exit\n")
|
| 507 |
print("="*70 + "\n")
|
| 508 |
-
|
| 509 |
# Initialize guild (one-time setup)
|
| 510 |
print("🔧 Initializing medical knowledge system...")
|
| 511 |
try:
|
|
@@ -518,78 +518,78 @@ def chat_interface():
|
|
| 518 |
print(" • Vector store exists (run: python scripts/setup_embeddings.py)")
|
| 519 |
print(" • Internet connection is available for cloud LLM")
|
| 520 |
return
|
| 521 |
-
|
| 522 |
# Main conversation loop
|
| 523 |
conversation_history = []
|
| 524 |
user_name = "there"
|
| 525 |
-
|
| 526 |
while True:
|
| 527 |
try:
|
| 528 |
# Get user input
|
| 529 |
user_input = input("You: ").strip()
|
| 530 |
-
|
| 531 |
if not user_input:
|
| 532 |
continue
|
| 533 |
-
|
| 534 |
# Handle special commands
|
| 535 |
if user_input.lower() in ['quit', 'exit', 'q']:
|
| 536 |
print("\n👋 Thank you for using MediGuard AI. Stay healthy!")
|
| 537 |
break
|
| 538 |
-
|
| 539 |
if user_input.lower() == 'help':
|
| 540 |
print_biomarker_help()
|
| 541 |
continue
|
| 542 |
-
|
| 543 |
if user_input.lower() == 'example':
|
| 544 |
run_example_case(guild)
|
| 545 |
continue
|
| 546 |
-
|
| 547 |
# Extract biomarkers from natural language
|
| 548 |
print("\n🔍 Analyzing your input...")
|
| 549 |
biomarkers, patient_context = extract_biomarkers(user_input)
|
| 550 |
-
|
| 551 |
if not biomarkers:
|
| 552 |
print("❌ I couldn't find any biomarker values in your message.")
|
| 553 |
print(" Try: 'My glucose is 140 and HbA1c is 7.5'")
|
| 554 |
print(" Or type 'help' to see all biomarkers I can analyze.\n")
|
| 555 |
continue
|
| 556 |
-
|
| 557 |
print(f"✅ Found {len(biomarkers)} biomarker(s): {', '.join(biomarkers.keys())}")
|
| 558 |
-
|
| 559 |
# Check if we have enough biomarkers (minimum 2)
|
| 560 |
if len(biomarkers) < 2:
|
| 561 |
print("⚠️ I need at least 2 biomarkers for a reliable analysis.")
|
| 562 |
print(" Can you provide more values?\n")
|
| 563 |
continue
|
| 564 |
-
|
| 565 |
# Generate disease prediction
|
| 566 |
print("🧠 Predicting likely condition...")
|
| 567 |
prediction = predict_disease_llm(biomarkers, patient_context)
|
| 568 |
print(f"✅ Predicted: {prediction['disease']} ({prediction['confidence']:.0%} confidence)")
|
| 569 |
-
|
| 570 |
# Create PatientInput
|
| 571 |
patient_input = PatientInput(
|
| 572 |
biomarkers=biomarkers,
|
| 573 |
model_prediction=prediction,
|
| 574 |
patient_context=patient_context if patient_context else {"source": "chat"}
|
| 575 |
)
|
| 576 |
-
|
| 577 |
# Run full RAG workflow
|
| 578 |
print("📚 Consulting medical knowledge base...")
|
| 579 |
print(" (This may take 15-25 seconds...)\n")
|
| 580 |
-
|
| 581 |
result = guild.run(patient_input)
|
| 582 |
-
|
| 583 |
# Format conversational response
|
| 584 |
response = format_conversational(result.get("final_response", result), user_name)
|
| 585 |
-
|
| 586 |
# Display response
|
| 587 |
print("\n" + "="*70)
|
| 588 |
print("🤖 RAG-BOT:")
|
| 589 |
print("="*70)
|
| 590 |
print(response)
|
| 591 |
print("="*70 + "\n")
|
| 592 |
-
|
| 593 |
# Save to history
|
| 594 |
conversation_history.append({
|
| 595 |
"user_input": user_input,
|
|
@@ -597,16 +597,16 @@ def chat_interface():
|
|
| 597 |
"prediction": prediction,
|
| 598 |
"result": result
|
| 599 |
})
|
| 600 |
-
|
| 601 |
# Ask if user wants to save report
|
| 602 |
save_choice = input("💾 Save detailed report to file? (y/n): ").strip().lower()
|
| 603 |
if save_choice == 'y':
|
| 604 |
save_report(result, biomarkers)
|
| 605 |
-
|
| 606 |
print("\nYou can:")
|
| 607 |
print(" • Enter more biomarkers for a new analysis")
|
| 608 |
print(" • Type 'quit' to exit\n")
|
| 609 |
-
|
| 610 |
except KeyboardInterrupt:
|
| 611 |
print("\n\n👋 Interrupted. Thank you for using MediGuard AI!")
|
| 612 |
break
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import json
|
|
|
|
|
|
|
| 7 |
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
import warnings
|
| 11 |
|
| 12 |
# ── Silence HuggingFace / transformers noise BEFORE any ML library is loaded ──
|
|
|
|
| 21 |
warnings.filterwarnings("ignore", message=".*class.*HuggingFaceEmbeddings.*was deprecated.*")
|
| 22 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 23 |
|
|
|
|
|
|
|
| 24 |
from datetime import datetime
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any
|
| 27 |
|
| 28 |
# Set UTF-8 encoding for Windows console
|
| 29 |
if sys.platform == 'win32':
|
|
|
|
| 40 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 41 |
|
| 42 |
from langchain_core.prompts import ChatPromptTemplate
|
| 43 |
+
|
| 44 |
from src.biomarker_normalization import normalize_biomarker_name
|
| 45 |
from src.llm_config import get_chat_model
|
|
|
|
| 46 |
from src.state import PatientInput
|
| 47 |
+
from src.workflow import create_guild
|
| 48 |
|
| 49 |
# ============================================================================
|
| 50 |
# BIOMARKER EXTRACTION PROMPT
|
|
|
|
| 82 |
# Component 1: Biomarker Extraction
|
| 83 |
# ============================================================================
|
| 84 |
|
| 85 |
+
def _parse_llm_json(content: str) -> dict[str, Any]:
|
| 86 |
"""Parse JSON payload from LLM output with fallback recovery."""
|
| 87 |
text = content.strip()
|
| 88 |
|
|
|
|
| 101 |
raise
|
| 102 |
|
| 103 |
|
| 104 |
+
def extract_biomarkers(user_message: str) -> tuple[dict[str, float], dict[str, Any]]:
|
| 105 |
"""
|
| 106 |
Extract biomarker values from natural language using LLM.
|
| 107 |
|
|
|
|
| 111 |
try:
|
| 112 |
llm = get_chat_model(temperature=0.0)
|
| 113 |
prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT)
|
| 114 |
+
|
| 115 |
chain = prompt | llm
|
| 116 |
response = chain.invoke({"user_message": user_message})
|
| 117 |
+
|
| 118 |
# Parse JSON from LLM response
|
| 119 |
content = response.content.strip()
|
| 120 |
+
|
| 121 |
extracted = _parse_llm_json(content)
|
| 122 |
biomarkers = extracted.get("biomarkers", {})
|
| 123 |
patient_context = extracted.get("patient_context", {})
|
| 124 |
+
|
| 125 |
# Normalize biomarker names
|
| 126 |
normalized = {}
|
| 127 |
for key, value in biomarkers.items():
|
|
|
|
| 131 |
except (ValueError, TypeError) as e:
|
| 132 |
print(f"⚠️ Skipping invalid value for {key}: {value} (error: {e})")
|
| 133 |
continue
|
| 134 |
+
|
| 135 |
# Clean up patient context (remove null values)
|
| 136 |
patient_context = {k: v for k, v in patient_context.items() if v is not None}
|
| 137 |
+
|
| 138 |
return normalized, patient_context
|
| 139 |
+
|
| 140 |
except Exception as e:
|
| 141 |
print(f"⚠️ Extraction failed: {e}")
|
| 142 |
import traceback
|
|
|
|
| 148 |
# Component 2: Disease Prediction
|
| 149 |
# ============================================================================
|
| 150 |
|
| 151 |
+
def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 152 |
"""
|
| 153 |
Simple rule-based disease prediction based on key biomarkers.
|
| 154 |
"""
|
|
|
|
| 159 |
"Thrombocytopenia": 0.0,
|
| 160 |
"Thalassemia": 0.0
|
| 161 |
}
|
| 162 |
+
|
| 163 |
# Helper: check both abbreviated and normalized biomarker names
|
| 164 |
# Returns None when biomarker is not present (avoids false triggers)
|
| 165 |
def _get(name, *alt_names):
|
| 166 |
+
val = biomarkers.get(name)
|
| 167 |
if val is not None:
|
| 168 |
return val
|
| 169 |
for alt in alt_names:
|
| 170 |
+
val = biomarkers.get(alt)
|
| 171 |
if val is not None:
|
| 172 |
return val
|
| 173 |
return None
|
|
|
|
| 181 |
scores["Diabetes"] += 0.2
|
| 182 |
if hba1c is not None and hba1c >= 6.5:
|
| 183 |
scores["Diabetes"] += 0.5
|
| 184 |
+
|
| 185 |
# Anemia indicators
|
| 186 |
hemoglobin = _get("Hemoglobin")
|
| 187 |
mcv = _get("Mean Corpuscular Volume", "MCV")
|
|
|
|
| 191 |
scores["Anemia"] += 0.2
|
| 192 |
if mcv is not None and mcv < 80:
|
| 193 |
scores["Anemia"] += 0.2
|
| 194 |
+
|
| 195 |
# Heart disease indicators
|
| 196 |
cholesterol = _get("Cholesterol")
|
| 197 |
troponin = _get("Troponin")
|
|
|
|
| 202 |
scores["Heart Disease"] += 0.6
|
| 203 |
if ldl is not None and ldl > 190:
|
| 204 |
scores["Heart Disease"] += 0.2
|
| 205 |
+
|
| 206 |
# Thrombocytopenia indicators
|
| 207 |
platelets = _get("Platelets")
|
| 208 |
if platelets is not None and platelets < 150000:
|
| 209 |
scores["Thrombocytopenia"] += 0.6
|
| 210 |
if platelets is not None and platelets < 50000:
|
| 211 |
scores["Thrombocytopenia"] += 0.3
|
| 212 |
+
|
| 213 |
# Thalassemia indicators (complex, simplified here)
|
| 214 |
if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
|
| 215 |
scores["Thalassemia"] += 0.4
|
| 216 |
+
|
| 217 |
# Find top prediction
|
| 218 |
top_disease = max(scores, key=scores.get)
|
| 219 |
confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation
|
| 220 |
+
|
| 221 |
if confidence == 0.0:
|
| 222 |
top_disease = "Undetermined"
|
| 223 |
+
|
| 224 |
# Normalize probabilities to sum to 1.0
|
| 225 |
total = sum(scores.values())
|
| 226 |
if total > 0:
|
| 227 |
probabilities = {k: v / total for k, v in scores.items()}
|
| 228 |
else:
|
| 229 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 230 |
+
|
| 231 |
return {
|
| 232 |
"disease": top_disease,
|
| 233 |
"confidence": confidence,
|
|
|
|
| 235 |
}
|
| 236 |
|
| 237 |
|
| 238 |
+
def predict_disease_llm(biomarkers: dict[str, float], patient_context: dict) -> dict[str, Any]:
|
| 239 |
"""
|
| 240 |
Use LLM to predict most likely disease based on biomarker pattern.
|
| 241 |
Falls back to rule-based if LLM fails.
|
| 242 |
"""
|
| 243 |
try:
|
| 244 |
llm = get_chat_model(temperature=0.0)
|
| 245 |
+
|
| 246 |
prompt = f"""You are a medical AI assistant. Based on these biomarker values,
|
| 247 |
predict the most likely disease from: Diabetes, Anemia, Heart Disease, Thrombocytopenia, Thalassemia.
|
| 248 |
|
|
|
|
| 265 |
}}
|
| 266 |
}}
|
| 267 |
"""
|
| 268 |
+
|
| 269 |
response = llm.invoke(prompt)
|
| 270 |
content = response.content.strip()
|
| 271 |
+
|
| 272 |
prediction = _parse_llm_json(content)
|
| 273 |
+
|
| 274 |
# Validate required fields
|
| 275 |
if "disease" in prediction and "confidence" in prediction and "probabilities" in prediction:
|
| 276 |
return prediction
|
| 277 |
else:
|
| 278 |
raise ValueError("Invalid prediction format")
|
| 279 |
+
|
| 280 |
except Exception as e:
|
| 281 |
print(f"⚠️ LLM prediction failed ({e}), using rule-based fallback")
|
| 282 |
import traceback
|
|
|
|
| 288 |
# Component 3: Conversational Formatter
|
| 289 |
# ============================================================================
|
| 290 |
|
| 291 |
+
def _coerce_to_dict(obj) -> dict:
|
| 292 |
"""Convert a Pydantic model or arbitrary object to a plain dict."""
|
| 293 |
if isinstance(obj, dict):
|
| 294 |
return obj
|
|
|
|
| 299 |
return {}
|
| 300 |
|
| 301 |
|
| 302 |
+
def format_conversational(result: dict[str, Any], user_name: str = "there") -> str:
|
| 303 |
"""
|
| 304 |
Format technical JSON output into conversational response.
|
| 305 |
"""
|
|
|
|
| 313 |
confidence = result.get("confidence_assessment", {}) or {}
|
| 314 |
# Normalize: items may be Pydantic SafetyAlert objects or plain dicts
|
| 315 |
alerts = [_coerce_to_dict(a) for a in (result.get("safety_alerts") or [])]
|
| 316 |
+
|
| 317 |
disease = prediction.get("primary_disease", "Unknown")
|
| 318 |
conf_score = prediction.get("confidence", 0.0)
|
| 319 |
+
|
| 320 |
# Build conversational response
|
| 321 |
response = []
|
| 322 |
+
|
| 323 |
# 1. Greeting and main finding
|
| 324 |
response.append(f"Hi {user_name}! 👋\n")
|
| 325 |
+
response.append("Based on your biomarkers, I analyzed your results.\n")
|
| 326 |
+
|
| 327 |
# 2. Primary diagnosis with confidence
|
| 328 |
emoji = "🔴" if conf_score >= 0.8 else "🟡" if conf_score >= 0.6 else "🟢"
|
| 329 |
response.append(f"{emoji} **Primary Finding:** {disease}")
|
| 330 |
response.append(f" Confidence: {conf_score:.0%}\n")
|
| 331 |
+
|
| 332 |
# 3. Critical safety alerts (if any)
|
| 333 |
critical_alerts = [a for a in alerts if a.get("severity") == "CRITICAL"]
|
| 334 |
if critical_alerts:
|
|
|
|
| 337 |
response.append(f" • {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}")
|
| 338 |
response.append(f" → {alert.get('action', 'Consult healthcare provider')}")
|
| 339 |
response.append("")
|
| 340 |
+
|
| 341 |
# 4. Key drivers explanation
|
| 342 |
key_drivers = prediction.get("key_drivers", [])
|
| 343 |
if key_drivers:
|
|
|
|
| 351 |
explanation = explanation[:147] + "..."
|
| 352 |
response.append(f" • **{biomarker}** ({value}): {explanation}")
|
| 353 |
response.append("")
|
| 354 |
+
|
| 355 |
# 5. What to do next (immediate actions)
|
| 356 |
immediate = recommendations.get("immediate_actions", [])
|
| 357 |
if immediate:
|
|
|
|
| 359 |
for i, action in enumerate(immediate[:3], 1):
|
| 360 |
response.append(f" {i}. {action}")
|
| 361 |
response.append("")
|
| 362 |
+
|
| 363 |
# 6. Lifestyle recommendations
|
| 364 |
lifestyle = recommendations.get("lifestyle_changes", [])
|
| 365 |
if lifestyle:
|
|
|
|
| 367 |
for i, change in enumerate(lifestyle[:3], 1):
|
| 368 |
response.append(f" {i}. {change}")
|
| 369 |
response.append("")
|
| 370 |
+
|
| 371 |
# 7. Disclaimer
|
| 372 |
response.append("ℹ️ **Important:** This is an AI-assisted analysis, NOT medical advice.")
|
| 373 |
response.append(" Please consult a healthcare professional for proper diagnosis and treatment.\n")
|
| 374 |
+
|
| 375 |
return "\n".join(response)
|
| 376 |
|
| 377 |
|
|
|
|
| 397 |
"""Run example diabetes patient case"""
|
| 398 |
print("\n📋 Running Example: Type 2 Diabetes Patient")
|
| 399 |
print(" 52-year-old male with elevated glucose and HbA1c\n")
|
| 400 |
+
|
| 401 |
example_biomarkers = {
|
| 402 |
"Glucose": 185.0,
|
| 403 |
"HbA1c": 8.2,
|
|
|
|
| 411 |
"Systolic Blood Pressure": 145,
|
| 412 |
"Diastolic Blood Pressure": 92
|
| 413 |
}
|
| 414 |
+
|
| 415 |
prediction = {
|
| 416 |
"disease": "Diabetes",
|
| 417 |
"confidence": 0.87,
|
|
|
|
| 423 |
"Thalassemia": 0.01
|
| 424 |
}
|
| 425 |
}
|
| 426 |
+
|
| 427 |
patient_input = PatientInput(
|
| 428 |
biomarkers=example_biomarkers,
|
| 429 |
model_prediction=prediction,
|
| 430 |
patient_context={"age": 52, "gender": "male", "bmi": 31.2}
|
| 431 |
)
|
| 432 |
+
|
| 433 |
print("🔄 Running analysis...\n")
|
| 434 |
result = guild.run(patient_input)
|
| 435 |
+
|
| 436 |
response = format_conversational(result.get("final_response", result), "there")
|
| 437 |
print("\n" + "="*70)
|
| 438 |
print("🤖 RAG-BOT:")
|
|
|
|
| 441 |
print("="*70 + "\n")
|
| 442 |
|
| 443 |
|
| 444 |
+
def save_report(result: dict, biomarkers: dict):
|
| 445 |
"""Save detailed JSON report to file"""
|
| 446 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 447 |
|
|
|
|
| 505 |
print(" 3. Type 'help' for biomarker list")
|
| 506 |
print(" 4. Type 'quit' to exit\n")
|
| 507 |
print("="*70 + "\n")
|
| 508 |
+
|
| 509 |
# Initialize guild (one-time setup)
|
| 510 |
print("🔧 Initializing medical knowledge system...")
|
| 511 |
try:
|
|
|
|
| 518 |
print(" • Vector store exists (run: python scripts/setup_embeddings.py)")
|
| 519 |
print(" • Internet connection is available for cloud LLM")
|
| 520 |
return
|
| 521 |
+
|
| 522 |
# Main conversation loop
|
| 523 |
conversation_history = []
|
| 524 |
user_name = "there"
|
| 525 |
+
|
| 526 |
while True:
|
| 527 |
try:
|
| 528 |
# Get user input
|
| 529 |
user_input = input("You: ").strip()
|
| 530 |
+
|
| 531 |
if not user_input:
|
| 532 |
continue
|
| 533 |
+
|
| 534 |
# Handle special commands
|
| 535 |
if user_input.lower() in ['quit', 'exit', 'q']:
|
| 536 |
print("\n👋 Thank you for using MediGuard AI. Stay healthy!")
|
| 537 |
break
|
| 538 |
+
|
| 539 |
if user_input.lower() == 'help':
|
| 540 |
print_biomarker_help()
|
| 541 |
continue
|
| 542 |
+
|
| 543 |
if user_input.lower() == 'example':
|
| 544 |
run_example_case(guild)
|
| 545 |
continue
|
| 546 |
+
|
| 547 |
# Extract biomarkers from natural language
|
| 548 |
print("\n🔍 Analyzing your input...")
|
| 549 |
biomarkers, patient_context = extract_biomarkers(user_input)
|
| 550 |
+
|
| 551 |
if not biomarkers:
|
| 552 |
print("❌ I couldn't find any biomarker values in your message.")
|
| 553 |
print(" Try: 'My glucose is 140 and HbA1c is 7.5'")
|
| 554 |
print(" Or type 'help' to see all biomarkers I can analyze.\n")
|
| 555 |
continue
|
| 556 |
+
|
| 557 |
print(f"✅ Found {len(biomarkers)} biomarker(s): {', '.join(biomarkers.keys())}")
|
| 558 |
+
|
| 559 |
# Check if we have enough biomarkers (minimum 2)
|
| 560 |
if len(biomarkers) < 2:
|
| 561 |
print("⚠️ I need at least 2 biomarkers for a reliable analysis.")
|
| 562 |
print(" Can you provide more values?\n")
|
| 563 |
continue
|
| 564 |
+
|
| 565 |
# Generate disease prediction
|
| 566 |
print("🧠 Predicting likely condition...")
|
| 567 |
prediction = predict_disease_llm(biomarkers, patient_context)
|
| 568 |
print(f"✅ Predicted: {prediction['disease']} ({prediction['confidence']:.0%} confidence)")
|
| 569 |
+
|
| 570 |
# Create PatientInput
|
| 571 |
patient_input = PatientInput(
|
| 572 |
biomarkers=biomarkers,
|
| 573 |
model_prediction=prediction,
|
| 574 |
patient_context=patient_context if patient_context else {"source": "chat"}
|
| 575 |
)
|
| 576 |
+
|
| 577 |
# Run full RAG workflow
|
| 578 |
print("📚 Consulting medical knowledge base...")
|
| 579 |
print(" (This may take 15-25 seconds...)\n")
|
| 580 |
+
|
| 581 |
result = guild.run(patient_input)
|
| 582 |
+
|
| 583 |
# Format conversational response
|
| 584 |
response = format_conversational(result.get("final_response", result), user_name)
|
| 585 |
+
|
| 586 |
# Display response
|
| 587 |
print("\n" + "="*70)
|
| 588 |
print("🤖 RAG-BOT:")
|
| 589 |
print("="*70)
|
| 590 |
print(response)
|
| 591 |
print("="*70 + "\n")
|
| 592 |
+
|
| 593 |
# Save to history
|
| 594 |
conversation_history.append({
|
| 595 |
"user_input": user_input,
|
|
|
|
| 597 |
"prediction": prediction,
|
| 598 |
"result": result
|
| 599 |
})
|
| 600 |
+
|
| 601 |
# Ask if user wants to save report
|
| 602 |
save_choice = input("💾 Save detailed report to file? (y/n): ").strip().lower()
|
| 603 |
if save_choice == 'y':
|
| 604 |
save_report(result, biomarkers)
|
| 605 |
+
|
| 606 |
print("\nYou can:")
|
| 607 |
print(" • Enter more biomarkers for a new analysis")
|
| 608 |
print(" • Type 'quit' to exit\n")
|
| 609 |
+
|
| 610 |
except KeyboardInterrupt:
|
| 611 |
print("\n\n👋 Interrupted. Thank you for using MediGuard AI!")
|
| 612 |
break
|
scripts/monitor_test.py
CHANGED
|
@@ -7,6 +7,6 @@ print("=" * 70)
|
|
| 7 |
for i in range(60): # Check for 5 minutes
|
| 8 |
time.sleep(5)
|
| 9 |
print(f"[{i*5}s] Test still running...")
|
| 10 |
-
|
| 11 |
print("\nTest should be complete or nearly complete.")
|
| 12 |
print("Check terminal output for results.")
|
|
|
|
| 7 |
for i in range(60): # Check for 5 minutes
|
| 8 |
time.sleep(5)
|
| 9 |
print(f"[{i*5}s] Test still running...")
|
| 10 |
+
|
| 11 |
print("\nTest should be complete or nearly complete.")
|
| 12 |
print("Check terminal output for results.")
|
scripts/setup_embeddings.py
CHANGED
|
@@ -2,22 +2,22 @@
|
|
| 2 |
Quick script to help set up Google API key for fast embeddings
|
| 3 |
"""
|
| 4 |
|
| 5 |
-
import os
|
| 6 |
from pathlib import Path
|
| 7 |
|
|
|
|
| 8 |
def setup_google_api_key():
|
| 9 |
"""Interactive setup for Google API key"""
|
| 10 |
-
|
| 11 |
print("="*70)
|
| 12 |
print("Fast Embeddings Setup - Google Gemini API")
|
| 13 |
print("="*70)
|
| 14 |
-
|
| 15 |
print("\nWhy Google Gemini?")
|
| 16 |
print(" - 100x faster than local Ollama (2 mins vs 30+ mins)")
|
| 17 |
print(" - FREE for standard usage")
|
| 18 |
print(" - High quality embeddings")
|
| 19 |
print(" - Automatic fallback to Ollama if unavailable")
|
| 20 |
-
|
| 21 |
print("\n" + "="*70)
|
| 22 |
print("Step 1: Get Your Free API Key")
|
| 23 |
print("="*70)
|
|
@@ -26,28 +26,28 @@ def setup_google_api_key():
|
|
| 26 |
print("\n2. Sign in with Google account")
|
| 27 |
print("3. Click 'Create API Key'")
|
| 28 |
print("4. Copy the key (starts with 'AIza...')")
|
| 29 |
-
|
| 30 |
input("\nPress ENTER when you have your API key ready...")
|
| 31 |
-
|
| 32 |
api_key = input("\nPaste your Google API key here: ").strip()
|
| 33 |
-
|
| 34 |
if not api_key:
|
| 35 |
print("\nNo API key provided. Using local Ollama instead.")
|
| 36 |
return False
|
| 37 |
-
|
| 38 |
if not api_key.startswith("AIza"):
|
| 39 |
print("\nWarning: Key doesn't start with 'AIza'. Are you sure this is correct?")
|
| 40 |
confirm = input("Continue anyway? (y/n): ").strip().lower()
|
| 41 |
if confirm != 'y':
|
| 42 |
return False
|
| 43 |
-
|
| 44 |
# Update .env file
|
| 45 |
env_path = Path(".env")
|
| 46 |
-
|
| 47 |
if env_path.exists():
|
| 48 |
-
with open(env_path
|
| 49 |
lines = f.readlines()
|
| 50 |
-
|
| 51 |
# Update or add GOOGLE_API_KEY
|
| 52 |
updated = False
|
| 53 |
for i, line in enumerate(lines):
|
|
@@ -55,17 +55,17 @@ def setup_google_api_key():
|
|
| 55 |
lines[i] = f'GOOGLE_API_KEY={api_key}\n'
|
| 56 |
updated = True
|
| 57 |
break
|
| 58 |
-
|
| 59 |
if not updated:
|
| 60 |
lines.insert(0, f'GOOGLE_API_KEY={api_key}\n')
|
| 61 |
-
|
| 62 |
with open(env_path, 'w') as f:
|
| 63 |
f.writelines(lines)
|
| 64 |
else:
|
| 65 |
# Create new .env file
|
| 66 |
with open(env_path, 'w') as f:
|
| 67 |
f.write(f'GOOGLE_API_KEY={api_key}\n')
|
| 68 |
-
|
| 69 |
print("\nAPI key saved to .env file!")
|
| 70 |
print("\n" + "="*70)
|
| 71 |
print("Step 2: Build Vector Store")
|
|
@@ -74,7 +74,7 @@ def setup_google_api_key():
|
|
| 74 |
print(" python src/pdf_processor.py")
|
| 75 |
print("\nChoose option 1 (Google Gemini) when prompted.")
|
| 76 |
print("\n" + "="*70)
|
| 77 |
-
|
| 78 |
return True
|
| 79 |
|
| 80 |
|
|
|
|
| 2 |
Quick script to help set up Google API key for fast embeddings
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
+
|
| 8 |
def setup_google_api_key():
|
| 9 |
"""Interactive setup for Google API key"""
|
| 10 |
+
|
| 11 |
print("="*70)
|
| 12 |
print("Fast Embeddings Setup - Google Gemini API")
|
| 13 |
print("="*70)
|
| 14 |
+
|
| 15 |
print("\nWhy Google Gemini?")
|
| 16 |
print(" - 100x faster than local Ollama (2 mins vs 30+ mins)")
|
| 17 |
print(" - FREE for standard usage")
|
| 18 |
print(" - High quality embeddings")
|
| 19 |
print(" - Automatic fallback to Ollama if unavailable")
|
| 20 |
+
|
| 21 |
print("\n" + "="*70)
|
| 22 |
print("Step 1: Get Your Free API Key")
|
| 23 |
print("="*70)
|
|
|
|
| 26 |
print("\n2. Sign in with Google account")
|
| 27 |
print("3. Click 'Create API Key'")
|
| 28 |
print("4. Copy the key (starts with 'AIza...')")
|
| 29 |
+
|
| 30 |
input("\nPress ENTER when you have your API key ready...")
|
| 31 |
+
|
| 32 |
api_key = input("\nPaste your Google API key here: ").strip()
|
| 33 |
+
|
| 34 |
if not api_key:
|
| 35 |
print("\nNo API key provided. Using local Ollama instead.")
|
| 36 |
return False
|
| 37 |
+
|
| 38 |
if not api_key.startswith("AIza"):
|
| 39 |
print("\nWarning: Key doesn't start with 'AIza'. Are you sure this is correct?")
|
| 40 |
confirm = input("Continue anyway? (y/n): ").strip().lower()
|
| 41 |
if confirm != 'y':
|
| 42 |
return False
|
| 43 |
+
|
| 44 |
# Update .env file
|
| 45 |
env_path = Path(".env")
|
| 46 |
+
|
| 47 |
if env_path.exists():
|
| 48 |
+
with open(env_path) as f:
|
| 49 |
lines = f.readlines()
|
| 50 |
+
|
| 51 |
# Update or add GOOGLE_API_KEY
|
| 52 |
updated = False
|
| 53 |
for i, line in enumerate(lines):
|
|
|
|
| 55 |
lines[i] = f'GOOGLE_API_KEY={api_key}\n'
|
| 56 |
updated = True
|
| 57 |
break
|
| 58 |
+
|
| 59 |
if not updated:
|
| 60 |
lines.insert(0, f'GOOGLE_API_KEY={api_key}\n')
|
| 61 |
+
|
| 62 |
with open(env_path, 'w') as f:
|
| 63 |
f.writelines(lines)
|
| 64 |
else:
|
| 65 |
# Create new .env file
|
| 66 |
with open(env_path, 'w') as f:
|
| 67 |
f.write(f'GOOGLE_API_KEY={api_key}\n')
|
| 68 |
+
|
| 69 |
print("\nAPI key saved to .env file!")
|
| 70 |
print("\n" + "="*70)
|
| 71 |
print("Step 2: Build Vector Store")
|
|
|
|
| 74 |
print(" python src/pdf_processor.py")
|
| 75 |
print("\nChoose option 1 (Google Gemini) when prompted.")
|
| 76 |
print("\n" + "="*70)
|
| 77 |
+
|
| 78 |
return True
|
| 79 |
|
| 80 |
|
scripts/test_chat_demo.py
CHANGED
|
@@ -4,7 +4,6 @@ Quick demo script to test the chatbot with pre-defined inputs
|
|
| 4 |
|
| 5 |
import subprocess
|
| 6 |
import sys
|
| 7 |
-
from pathlib import Path
|
| 8 |
|
| 9 |
# Test inputs
|
| 10 |
test_cases = [
|
|
@@ -36,16 +35,16 @@ try:
|
|
| 36 |
encoding='utf-8',
|
| 37 |
errors='replace'
|
| 38 |
)
|
| 39 |
-
|
| 40 |
print("STDOUT:")
|
| 41 |
print(result.stdout)
|
| 42 |
-
|
| 43 |
if result.stderr:
|
| 44 |
print("\nSTDERR:")
|
| 45 |
print(result.stderr)
|
| 46 |
-
|
| 47 |
print(f"\nExit code: {result.returncode}")
|
| 48 |
-
|
| 49 |
except subprocess.TimeoutExpired:
|
| 50 |
print("⚠️ Test timed out after 120 seconds")
|
| 51 |
except Exception as e:
|
|
|
|
| 4 |
|
| 5 |
import subprocess
|
| 6 |
import sys
|
|
|
|
| 7 |
|
| 8 |
# Test inputs
|
| 9 |
test_cases = [
|
|
|
|
| 35 |
encoding='utf-8',
|
| 36 |
errors='replace'
|
| 37 |
)
|
| 38 |
+
|
| 39 |
print("STDOUT:")
|
| 40 |
print(result.stdout)
|
| 41 |
+
|
| 42 |
if result.stderr:
|
| 43 |
print("\nSTDERR:")
|
| 44 |
print(result.stderr)
|
| 45 |
+
|
| 46 |
print(f"\nExit code: {result.returncode}")
|
| 47 |
+
|
| 48 |
except subprocess.TimeoutExpired:
|
| 49 |
print("⚠️ Test timed out after 120 seconds")
|
| 50 |
except Exception as e:
|
scripts/test_extraction.py
CHANGED
|
@@ -4,6 +4,7 @@ Quick test to verify biomarker extraction is working
|
|
| 4 |
|
| 5 |
import sys
|
| 6 |
from pathlib import Path
|
|
|
|
| 7 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 8 |
|
| 9 |
from scripts.chat import extract_biomarkers, predict_disease_llm
|
|
@@ -22,25 +23,25 @@ print("="*70)
|
|
| 22 |
for i, test_input in enumerate(test_inputs, 1):
|
| 23 |
print(f"\n[Test {i}] Input: '{test_input}'")
|
| 24 |
print("-"*70)
|
| 25 |
-
|
| 26 |
biomarkers, context = extract_biomarkers(test_input)
|
| 27 |
-
|
| 28 |
if biomarkers:
|
| 29 |
print(f"✅ SUCCESS: Found {len(biomarkers)} biomarkers")
|
| 30 |
for name, value in biomarkers.items():
|
| 31 |
print(f" - {name}: {value}")
|
| 32 |
-
|
| 33 |
if context:
|
| 34 |
print(f" Context: {context}")
|
| 35 |
-
|
| 36 |
# Test prediction
|
| 37 |
print("\n Testing prediction...")
|
| 38 |
prediction = predict_disease_llm(biomarkers, context)
|
| 39 |
print(f" Predicted: {prediction['disease']} ({prediction['confidence']:.0%})")
|
| 40 |
-
|
| 41 |
else:
|
| 42 |
-
print(
|
| 43 |
-
|
| 44 |
print()
|
| 45 |
|
| 46 |
print("="*70)
|
|
|
|
| 4 |
|
| 5 |
import sys
|
| 6 |
from pathlib import Path
|
| 7 |
+
|
| 8 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 9 |
|
| 10 |
from scripts.chat import extract_biomarkers, predict_disease_llm
|
|
|
|
| 23 |
for i, test_input in enumerate(test_inputs, 1):
|
| 24 |
print(f"\n[Test {i}] Input: '{test_input}'")
|
| 25 |
print("-"*70)
|
| 26 |
+
|
| 27 |
biomarkers, context = extract_biomarkers(test_input)
|
| 28 |
+
|
| 29 |
if biomarkers:
|
| 30 |
print(f"✅ SUCCESS: Found {len(biomarkers)} biomarkers")
|
| 31 |
for name, value in biomarkers.items():
|
| 32 |
print(f" - {name}: {value}")
|
| 33 |
+
|
| 34 |
if context:
|
| 35 |
print(f" Context: {context}")
|
| 36 |
+
|
| 37 |
# Test prediction
|
| 38 |
print("\n Testing prediction...")
|
| 39 |
prediction = predict_disease_llm(biomarkers, context)
|
| 40 |
print(f" Predicted: {prediction['disease']} ({prediction['confidence']:.0%})")
|
| 41 |
+
|
| 42 |
else:
|
| 43 |
+
print("❌ FAILED: No biomarkers extracted")
|
| 44 |
+
|
| 45 |
print()
|
| 46 |
|
| 47 |
print("="*70)
|
src/agents/biomarker_analyzer.py
CHANGED
|
@@ -3,19 +3,19 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Biomarker Analyzer Agent - Validates biomarker values and flags anomalies
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
|
| 7 |
-
from src.state import GuildState, AgentOutput, BiomarkerFlag
|
| 8 |
from src.biomarker_validator import BiomarkerValidator
|
| 9 |
from src.llm_config import llm_config
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class BiomarkerAnalyzerAgent:
|
| 13 |
"""Agent that validates biomarker values and generates comprehensive analysis"""
|
| 14 |
-
|
| 15 |
def __init__(self):
|
| 16 |
self.validator = BiomarkerValidator()
|
| 17 |
self.llm = llm_config.analyzer
|
| 18 |
-
|
| 19 |
def analyze(self, state: GuildState) -> GuildState:
|
| 20 |
"""
|
| 21 |
Main agent function to analyze biomarkers.
|
|
@@ -29,12 +29,12 @@ class BiomarkerAnalyzerAgent:
|
|
| 29 |
print("\n" + "="*70)
|
| 30 |
print("EXECUTING: Biomarker Analyzer Agent")
|
| 31 |
print("="*70)
|
| 32 |
-
|
| 33 |
biomarkers = state['patient_biomarkers']
|
| 34 |
patient_context = state.get('patient_context', {})
|
| 35 |
gender = patient_context.get('gender') # None if not provided — uses non-gender-specific ranges
|
| 36 |
predicted_disease = state['model_prediction']['disease']
|
| 37 |
-
|
| 38 |
# Validate all biomarkers
|
| 39 |
print(f"\nValidating {len(biomarkers)} biomarkers...")
|
| 40 |
flags, alerts = self.validator.validate_all(
|
|
@@ -42,13 +42,13 @@ class BiomarkerAnalyzerAgent:
|
|
| 42 |
gender=gender,
|
| 43 |
threshold_pct=state['sop'].biomarker_analyzer_threshold
|
| 44 |
)
|
| 45 |
-
|
| 46 |
# Get disease-relevant biomarkers
|
| 47 |
relevant_biomarkers = self.validator.get_disease_relevant_biomarkers(predicted_disease)
|
| 48 |
-
|
| 49 |
# Generate summary using LLM
|
| 50 |
summary = self._generate_summary(biomarkers, flags, alerts, relevant_biomarkers, predicted_disease)
|
| 51 |
-
|
| 52 |
findings = {
|
| 53 |
"biomarker_flags": [flag.model_dump() for flag in flags],
|
| 54 |
"safety_alerts": [alert.model_dump() for alert in alerts],
|
|
@@ -62,35 +62,35 @@ class BiomarkerAnalyzerAgent:
|
|
| 62 |
agent_name="Biomarker Analyzer",
|
| 63 |
findings=findings
|
| 64 |
)
|
| 65 |
-
|
| 66 |
# Update state
|
| 67 |
print("\nAnalysis complete:")
|
| 68 |
print(f" - {len(flags)} biomarkers validated")
|
| 69 |
print(f" - {len([f for f in flags if f.status != 'NORMAL'])} out-of-range values")
|
| 70 |
print(f" - {len(alerts)} safety alerts generated")
|
| 71 |
print(f" - {len(relevant_biomarkers)} disease-relevant biomarkers identified")
|
| 72 |
-
|
| 73 |
return {
|
| 74 |
'agent_outputs': [output],
|
| 75 |
'biomarker_flags': flags,
|
| 76 |
'safety_alerts': alerts,
|
| 77 |
'biomarker_analysis': findings
|
| 78 |
}
|
| 79 |
-
|
| 80 |
def _generate_summary(
|
| 81 |
self,
|
| 82 |
-
biomarkers:
|
| 83 |
-
flags:
|
| 84 |
-
alerts:
|
| 85 |
-
relevant_biomarkers:
|
| 86 |
disease: str
|
| 87 |
) -> str:
|
| 88 |
"""Generate a concise summary of biomarker findings"""
|
| 89 |
-
|
| 90 |
# Count anomalies
|
| 91 |
critical = [f for f in flags if 'CRITICAL' in f.status]
|
| 92 |
high_low = [f for f in flags if f.status in ['HIGH', 'LOW']]
|
| 93 |
-
|
| 94 |
prompt = f"""You are a medical data analyst. Provide a brief, clinical summary of these biomarker results.
|
| 95 |
|
| 96 |
**Patient Context:**
|
|
@@ -115,24 +115,24 @@ Keep it concise and clinical."""
|
|
| 115 |
except Exception as e:
|
| 116 |
print(f"Warning: LLM summary generation failed: {e}")
|
| 117 |
return f"Biomarker analysis complete. {len(critical)} critical values, {len(high_low)} out-of-range values detected."
|
| 118 |
-
|
| 119 |
def _format_key_findings(self, critical, high_low, relevant):
|
| 120 |
"""Format findings for LLM prompt"""
|
| 121 |
findings = []
|
| 122 |
-
|
| 123 |
if critical:
|
| 124 |
findings.append("CRITICAL VALUES:")
|
| 125 |
for f in critical[:3]: # Top 3
|
| 126 |
findings.append(f" - {f.name}: {f.value} {f.unit} ({f.status})")
|
| 127 |
-
|
| 128 |
if high_low:
|
| 129 |
findings.append("\nOUT-OF-RANGE VALUES:")
|
| 130 |
for f in high_low[:5]: # Top 5
|
| 131 |
findings.append(f" - {f.name}: {f.value} {f.unit} ({f.status})")
|
| 132 |
-
|
| 133 |
if relevant:
|
| 134 |
findings.append(f"\nDISEASE-RELEVANT BIOMARKERS: {', '.join(relevant[:5])}")
|
| 135 |
-
|
| 136 |
return "\n".join(findings) if findings else "All biomarkers within normal range."
|
| 137 |
|
| 138 |
|
|
|
|
| 3 |
Biomarker Analyzer Agent - Validates biomarker values and flags anomalies
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
|
|
|
|
| 7 |
from src.biomarker_validator import BiomarkerValidator
|
| 8 |
from src.llm_config import llm_config
|
| 9 |
+
from src.state import AgentOutput, BiomarkerFlag, GuildState
|
| 10 |
|
| 11 |
|
| 12 |
class BiomarkerAnalyzerAgent:
|
| 13 |
"""Agent that validates biomarker values and generates comprehensive analysis"""
|
| 14 |
+
|
| 15 |
def __init__(self):
|
| 16 |
self.validator = BiomarkerValidator()
|
| 17 |
self.llm = llm_config.analyzer
|
| 18 |
+
|
| 19 |
def analyze(self, state: GuildState) -> GuildState:
|
| 20 |
"""
|
| 21 |
Main agent function to analyze biomarkers.
|
|
|
|
| 29 |
print("\n" + "="*70)
|
| 30 |
print("EXECUTING: Biomarker Analyzer Agent")
|
| 31 |
print("="*70)
|
| 32 |
+
|
| 33 |
biomarkers = state['patient_biomarkers']
|
| 34 |
patient_context = state.get('patient_context', {})
|
| 35 |
gender = patient_context.get('gender') # None if not provided — uses non-gender-specific ranges
|
| 36 |
predicted_disease = state['model_prediction']['disease']
|
| 37 |
+
|
| 38 |
# Validate all biomarkers
|
| 39 |
print(f"\nValidating {len(biomarkers)} biomarkers...")
|
| 40 |
flags, alerts = self.validator.validate_all(
|
|
|
|
| 42 |
gender=gender,
|
| 43 |
threshold_pct=state['sop'].biomarker_analyzer_threshold
|
| 44 |
)
|
| 45 |
+
|
| 46 |
# Get disease-relevant biomarkers
|
| 47 |
relevant_biomarkers = self.validator.get_disease_relevant_biomarkers(predicted_disease)
|
| 48 |
+
|
| 49 |
# Generate summary using LLM
|
| 50 |
summary = self._generate_summary(biomarkers, flags, alerts, relevant_biomarkers, predicted_disease)
|
| 51 |
+
|
| 52 |
findings = {
|
| 53 |
"biomarker_flags": [flag.model_dump() for flag in flags],
|
| 54 |
"safety_alerts": [alert.model_dump() for alert in alerts],
|
|
|
|
| 62 |
agent_name="Biomarker Analyzer",
|
| 63 |
findings=findings
|
| 64 |
)
|
| 65 |
+
|
| 66 |
# Update state
|
| 67 |
print("\nAnalysis complete:")
|
| 68 |
print(f" - {len(flags)} biomarkers validated")
|
| 69 |
print(f" - {len([f for f in flags if f.status != 'NORMAL'])} out-of-range values")
|
| 70 |
print(f" - {len(alerts)} safety alerts generated")
|
| 71 |
print(f" - {len(relevant_biomarkers)} disease-relevant biomarkers identified")
|
| 72 |
+
|
| 73 |
return {
|
| 74 |
'agent_outputs': [output],
|
| 75 |
'biomarker_flags': flags,
|
| 76 |
'safety_alerts': alerts,
|
| 77 |
'biomarker_analysis': findings
|
| 78 |
}
|
| 79 |
+
|
| 80 |
def _generate_summary(
|
| 81 |
self,
|
| 82 |
+
biomarkers: dict[str, float],
|
| 83 |
+
flags: list[BiomarkerFlag],
|
| 84 |
+
alerts: list,
|
| 85 |
+
relevant_biomarkers: list[str],
|
| 86 |
disease: str
|
| 87 |
) -> str:
|
| 88 |
"""Generate a concise summary of biomarker findings"""
|
| 89 |
+
|
| 90 |
# Count anomalies
|
| 91 |
critical = [f for f in flags if 'CRITICAL' in f.status]
|
| 92 |
high_low = [f for f in flags if f.status in ['HIGH', 'LOW']]
|
| 93 |
+
|
| 94 |
prompt = f"""You are a medical data analyst. Provide a brief, clinical summary of these biomarker results.
|
| 95 |
|
| 96 |
**Patient Context:**
|
|
|
|
| 115 |
except Exception as e:
|
| 116 |
print(f"Warning: LLM summary generation failed: {e}")
|
| 117 |
return f"Biomarker analysis complete. {len(critical)} critical values, {len(high_low)} out-of-range values detected."
|
| 118 |
+
|
| 119 |
def _format_key_findings(self, critical, high_low, relevant):
|
| 120 |
"""Format findings for LLM prompt"""
|
| 121 |
findings = []
|
| 122 |
+
|
| 123 |
if critical:
|
| 124 |
findings.append("CRITICAL VALUES:")
|
| 125 |
for f in critical[:3]: # Top 3
|
| 126 |
findings.append(f" - {f.name}: {f.value} {f.unit} ({f.status})")
|
| 127 |
+
|
| 128 |
if high_low:
|
| 129 |
findings.append("\nOUT-OF-RANGE VALUES:")
|
| 130 |
for f in high_low[:5]: # Top 5
|
| 131 |
findings.append(f" - {f.name}: {f.value} {f.unit} ({f.status})")
|
| 132 |
+
|
| 133 |
if relevant:
|
| 134 |
findings.append(f"\nDISEASE-RELEVANT BIOMARKERS: {', '.join(relevant[:5])}")
|
| 135 |
+
|
| 136 |
return "\n".join(findings) if findings else "All biomarkers within normal range."
|
| 137 |
|
| 138 |
|
src/agents/biomarker_linker.py
CHANGED
|
@@ -3,15 +3,15 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Biomarker-Disease Linker Agent - Connects biomarker values to predicted disease
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
from src.llm_config import llm_config
|
| 9 |
-
from
|
| 10 |
|
| 11 |
|
| 12 |
class BiomarkerDiseaseLinkerAgent:
|
| 13 |
"""Agent that links specific biomarker values to the predicted disease"""
|
| 14 |
-
|
| 15 |
def __init__(self, retriever):
|
| 16 |
"""
|
| 17 |
Initialize with a retriever for biomarker-disease connections.
|
|
@@ -21,7 +21,7 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 21 |
"""
|
| 22 |
self.retriever = retriever
|
| 23 |
self.llm = llm_config.explainer
|
| 24 |
-
|
| 25 |
def link(self, state: GuildState) -> GuildState:
|
| 26 |
"""
|
| 27 |
Link biomarkers to disease prediction.
|
|
@@ -35,14 +35,14 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 35 |
print("\n" + "="*70)
|
| 36 |
print("EXECUTING: Biomarker-Disease Linker Agent (RAG)")
|
| 37 |
print("="*70)
|
| 38 |
-
|
| 39 |
model_prediction = state['model_prediction']
|
| 40 |
disease = model_prediction['disease']
|
| 41 |
biomarkers = state['patient_biomarkers']
|
| 42 |
-
|
| 43 |
# Get biomarker analysis from previous agent
|
| 44 |
biomarker_analysis = state.get('biomarker_analysis') or {}
|
| 45 |
-
|
| 46 |
# Identify key drivers
|
| 47 |
print(f"\nIdentifying key drivers for {disease}...")
|
| 48 |
key_drivers, citations_missing = self._identify_key_drivers(
|
|
@@ -51,9 +51,9 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 51 |
biomarker_analysis,
|
| 52 |
state
|
| 53 |
)
|
| 54 |
-
|
| 55 |
print(f"Identified {len(key_drivers)} key biomarker drivers")
|
| 56 |
-
|
| 57 |
# Create agent output
|
| 58 |
output = AgentOutput(
|
| 59 |
agent_name="Biomarker-Disease Linker",
|
|
@@ -65,45 +65,45 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 65 |
"citations_missing": citations_missing
|
| 66 |
}
|
| 67 |
)
|
| 68 |
-
|
| 69 |
# Update state
|
| 70 |
print("\nBiomarker-disease linking complete")
|
| 71 |
-
|
| 72 |
return {'agent_outputs': [output]}
|
| 73 |
-
|
| 74 |
def _identify_key_drivers(
|
| 75 |
self,
|
| 76 |
disease: str,
|
| 77 |
-
biomarkers:
|
| 78 |
analysis: dict,
|
| 79 |
state: GuildState
|
| 80 |
-
) -> tuple[
|
| 81 |
"""Identify which biomarkers are driving the disease prediction"""
|
| 82 |
-
|
| 83 |
# Get out-of-range biomarkers from analysis
|
| 84 |
flags = analysis.get('biomarker_flags', [])
|
| 85 |
abnormal_biomarkers = [
|
| 86 |
-
f for f in flags
|
| 87 |
if f['status'] != 'NORMAL'
|
| 88 |
]
|
| 89 |
-
|
| 90 |
# Get disease-relevant biomarkers
|
| 91 |
relevant = analysis.get('relevant_biomarkers', [])
|
| 92 |
-
|
| 93 |
# Focus on biomarkers that are both abnormal AND disease-relevant
|
| 94 |
key_biomarkers = [
|
| 95 |
f for f in abnormal_biomarkers
|
| 96 |
if f['name'] in relevant
|
| 97 |
]
|
| 98 |
-
|
| 99 |
# If no key biomarkers found, use top abnormal ones
|
| 100 |
if not key_biomarkers:
|
| 101 |
key_biomarkers = abnormal_biomarkers[:5]
|
| 102 |
-
|
| 103 |
print(f" Analyzing {len(key_biomarkers)} key biomarkers...")
|
| 104 |
-
|
| 105 |
# Generate key drivers with evidence
|
| 106 |
-
key_drivers:
|
| 107 |
citations_missing = False
|
| 108 |
for biomarker_flag in key_biomarkers[:5]: # Top 5
|
| 109 |
driver, driver_missing = self._create_key_driver(
|
|
@@ -115,7 +115,7 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 115 |
citations_missing = citations_missing or driver_missing
|
| 116 |
|
| 117 |
return key_drivers, citations_missing
|
| 118 |
-
|
| 119 |
def _create_key_driver(
|
| 120 |
self,
|
| 121 |
biomarker_flag: dict,
|
|
@@ -123,15 +123,15 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 123 |
state: GuildState
|
| 124 |
) -> tuple[KeyDriver, bool]:
|
| 125 |
"""Create a KeyDriver object with evidence from RAG"""
|
| 126 |
-
|
| 127 |
name = biomarker_flag['name']
|
| 128 |
value = biomarker_flag['value']
|
| 129 |
unit = biomarker_flag['unit']
|
| 130 |
status = biomarker_flag['status']
|
| 131 |
-
|
| 132 |
# Retrieve evidence linking this biomarker to the disease
|
| 133 |
query = f"How does {name} relate to {disease}? What does {status} {name} indicate?"
|
| 134 |
-
|
| 135 |
citations_missing = False
|
| 136 |
try:
|
| 137 |
docs = self.retriever.invoke(query)
|
|
@@ -147,12 +147,12 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 147 |
evidence_text = f"{status} {name} may be related to {disease}."
|
| 148 |
contribution = "Unknown"
|
| 149 |
citations_missing = True
|
| 150 |
-
|
| 151 |
# Generate explanation using LLM
|
| 152 |
explanation = self._generate_explanation(
|
| 153 |
name, value, unit, status, disease, evidence_text
|
| 154 |
)
|
| 155 |
-
|
| 156 |
driver = KeyDriver(
|
| 157 |
biomarker=name,
|
| 158 |
value=value,
|
|
@@ -162,12 +162,12 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 162 |
)
|
| 163 |
|
| 164 |
return driver, citations_missing
|
| 165 |
-
|
| 166 |
def _extract_evidence(self, docs: list, biomarker: str, disease: str) -> str:
|
| 167 |
"""Extract relevant evidence from retrieved documents"""
|
| 168 |
if not docs:
|
| 169 |
return f"Limited evidence available for {biomarker} in {disease}."
|
| 170 |
-
|
| 171 |
# Combine relevant passages
|
| 172 |
evidence = []
|
| 173 |
for doc in docs[:2]: # Top 2 docs
|
|
@@ -175,17 +175,17 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 175 |
# Extract sentences mentioning the biomarker
|
| 176 |
sentences = content.split('.')
|
| 177 |
relevant_sentences = [
|
| 178 |
-
s.strip() for s in sentences
|
| 179 |
if biomarker.lower() in s.lower() or disease.lower() in s.lower()
|
| 180 |
]
|
| 181 |
evidence.extend(relevant_sentences[:2])
|
| 182 |
-
|
| 183 |
return ". ".join(evidence[:3]) + "." if evidence else content[:300]
|
| 184 |
-
|
| 185 |
def _estimate_contribution(self, biomarker_flag: dict, doc_count: int) -> str:
|
| 186 |
"""Estimate the contribution percentage (simplified)"""
|
| 187 |
status = biomarker_flag['status']
|
| 188 |
-
|
| 189 |
# Simple heuristic based on severity
|
| 190 |
if 'CRITICAL' in status:
|
| 191 |
base = 40
|
|
@@ -193,13 +193,13 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 193 |
base = 25
|
| 194 |
else:
|
| 195 |
base = 10
|
| 196 |
-
|
| 197 |
# Adjust based on evidence strength
|
| 198 |
evidence_boost = min(doc_count * 2, 15)
|
| 199 |
-
|
| 200 |
total = min(base + evidence_boost, 60)
|
| 201 |
return f"{total}%"
|
| 202 |
-
|
| 203 |
def _generate_explanation(
|
| 204 |
self,
|
| 205 |
biomarker: str,
|
|
@@ -210,7 +210,7 @@ class BiomarkerDiseaseLinkerAgent:
|
|
| 210 |
evidence: str
|
| 211 |
) -> str:
|
| 212 |
"""Generate patient-friendly explanation"""
|
| 213 |
-
|
| 214 |
prompt = f"""Explain in 1-2 sentences how this biomarker result relates to {disease}:
|
| 215 |
|
| 216 |
Biomarker: {biomarker}
|
|
@@ -220,11 +220,11 @@ Status: {status}
|
|
| 220 |
Medical Evidence: {evidence}
|
| 221 |
|
| 222 |
Write in patient-friendly language, explaining what this means for the diagnosis."""
|
| 223 |
-
|
| 224 |
try:
|
| 225 |
response = self.llm.invoke(prompt)
|
| 226 |
return response.content.strip()
|
| 227 |
-
except Exception
|
| 228 |
return f"{biomarker} at {value} {unit} is {status}, which may be associated with {disease}."
|
| 229 |
|
| 230 |
|
|
|
|
| 3 |
Biomarker-Disease Linker Agent - Connects biomarker values to predicted disease
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
|
| 7 |
+
|
| 8 |
from src.llm_config import llm_config
|
| 9 |
+
from src.state import AgentOutput, GuildState, KeyDriver
|
| 10 |
|
| 11 |
|
| 12 |
class BiomarkerDiseaseLinkerAgent:
|
| 13 |
"""Agent that links specific biomarker values to the predicted disease"""
|
| 14 |
+
|
| 15 |
def __init__(self, retriever):
|
| 16 |
"""
|
| 17 |
Initialize with a retriever for biomarker-disease connections.
|
|
|
|
| 21 |
"""
|
| 22 |
self.retriever = retriever
|
| 23 |
self.llm = llm_config.explainer
|
| 24 |
+
|
| 25 |
def link(self, state: GuildState) -> GuildState:
|
| 26 |
"""
|
| 27 |
Link biomarkers to disease prediction.
|
|
|
|
| 35 |
print("\n" + "="*70)
|
| 36 |
print("EXECUTING: Biomarker-Disease Linker Agent (RAG)")
|
| 37 |
print("="*70)
|
| 38 |
+
|
| 39 |
model_prediction = state['model_prediction']
|
| 40 |
disease = model_prediction['disease']
|
| 41 |
biomarkers = state['patient_biomarkers']
|
| 42 |
+
|
| 43 |
# Get biomarker analysis from previous agent
|
| 44 |
biomarker_analysis = state.get('biomarker_analysis') or {}
|
| 45 |
+
|
| 46 |
# Identify key drivers
|
| 47 |
print(f"\nIdentifying key drivers for {disease}...")
|
| 48 |
key_drivers, citations_missing = self._identify_key_drivers(
|
|
|
|
| 51 |
biomarker_analysis,
|
| 52 |
state
|
| 53 |
)
|
| 54 |
+
|
| 55 |
print(f"Identified {len(key_drivers)} key biomarker drivers")
|
| 56 |
+
|
| 57 |
# Create agent output
|
| 58 |
output = AgentOutput(
|
| 59 |
agent_name="Biomarker-Disease Linker",
|
|
|
|
| 65 |
"citations_missing": citations_missing
|
| 66 |
}
|
| 67 |
)
|
| 68 |
+
|
| 69 |
# Update state
|
| 70 |
print("\nBiomarker-disease linking complete")
|
| 71 |
+
|
| 72 |
return {'agent_outputs': [output]}
|
| 73 |
+
|
| 74 |
def _identify_key_drivers(
|
| 75 |
self,
|
| 76 |
disease: str,
|
| 77 |
+
biomarkers: dict[str, float],
|
| 78 |
analysis: dict,
|
| 79 |
state: GuildState
|
| 80 |
+
) -> tuple[list[KeyDriver], bool]:
|
| 81 |
"""Identify which biomarkers are driving the disease prediction"""
|
| 82 |
+
|
| 83 |
# Get out-of-range biomarkers from analysis
|
| 84 |
flags = analysis.get('biomarker_flags', [])
|
| 85 |
abnormal_biomarkers = [
|
| 86 |
+
f for f in flags
|
| 87 |
if f['status'] != 'NORMAL'
|
| 88 |
]
|
| 89 |
+
|
| 90 |
# Get disease-relevant biomarkers
|
| 91 |
relevant = analysis.get('relevant_biomarkers', [])
|
| 92 |
+
|
| 93 |
# Focus on biomarkers that are both abnormal AND disease-relevant
|
| 94 |
key_biomarkers = [
|
| 95 |
f for f in abnormal_biomarkers
|
| 96 |
if f['name'] in relevant
|
| 97 |
]
|
| 98 |
+
|
| 99 |
# If no key biomarkers found, use top abnormal ones
|
| 100 |
if not key_biomarkers:
|
| 101 |
key_biomarkers = abnormal_biomarkers[:5]
|
| 102 |
+
|
| 103 |
print(f" Analyzing {len(key_biomarkers)} key biomarkers...")
|
| 104 |
+
|
| 105 |
# Generate key drivers with evidence
|
| 106 |
+
key_drivers: list[KeyDriver] = []
|
| 107 |
citations_missing = False
|
| 108 |
for biomarker_flag in key_biomarkers[:5]: # Top 5
|
| 109 |
driver, driver_missing = self._create_key_driver(
|
|
|
|
| 115 |
citations_missing = citations_missing or driver_missing
|
| 116 |
|
| 117 |
return key_drivers, citations_missing
|
| 118 |
+
|
| 119 |
def _create_key_driver(
|
| 120 |
self,
|
| 121 |
biomarker_flag: dict,
|
|
|
|
| 123 |
state: GuildState
|
| 124 |
) -> tuple[KeyDriver, bool]:
|
| 125 |
"""Create a KeyDriver object with evidence from RAG"""
|
| 126 |
+
|
| 127 |
name = biomarker_flag['name']
|
| 128 |
value = biomarker_flag['value']
|
| 129 |
unit = biomarker_flag['unit']
|
| 130 |
status = biomarker_flag['status']
|
| 131 |
+
|
| 132 |
# Retrieve evidence linking this biomarker to the disease
|
| 133 |
query = f"How does {name} relate to {disease}? What does {status} {name} indicate?"
|
| 134 |
+
|
| 135 |
citations_missing = False
|
| 136 |
try:
|
| 137 |
docs = self.retriever.invoke(query)
|
|
|
|
| 147 |
evidence_text = f"{status} {name} may be related to {disease}."
|
| 148 |
contribution = "Unknown"
|
| 149 |
citations_missing = True
|
| 150 |
+
|
| 151 |
# Generate explanation using LLM
|
| 152 |
explanation = self._generate_explanation(
|
| 153 |
name, value, unit, status, disease, evidence_text
|
| 154 |
)
|
| 155 |
+
|
| 156 |
driver = KeyDriver(
|
| 157 |
biomarker=name,
|
| 158 |
value=value,
|
|
|
|
| 162 |
)
|
| 163 |
|
| 164 |
return driver, citations_missing
|
| 165 |
+
|
| 166 |
def _extract_evidence(self, docs: list, biomarker: str, disease: str) -> str:
|
| 167 |
"""Extract relevant evidence from retrieved documents"""
|
| 168 |
if not docs:
|
| 169 |
return f"Limited evidence available for {biomarker} in {disease}."
|
| 170 |
+
|
| 171 |
# Combine relevant passages
|
| 172 |
evidence = []
|
| 173 |
for doc in docs[:2]: # Top 2 docs
|
|
|
|
| 175 |
# Extract sentences mentioning the biomarker
|
| 176 |
sentences = content.split('.')
|
| 177 |
relevant_sentences = [
|
| 178 |
+
s.strip() for s in sentences
|
| 179 |
if biomarker.lower() in s.lower() or disease.lower() in s.lower()
|
| 180 |
]
|
| 181 |
evidence.extend(relevant_sentences[:2])
|
| 182 |
+
|
| 183 |
return ". ".join(evidence[:3]) + "." if evidence else content[:300]
|
| 184 |
+
|
| 185 |
def _estimate_contribution(self, biomarker_flag: dict, doc_count: int) -> str:
|
| 186 |
"""Estimate the contribution percentage (simplified)"""
|
| 187 |
status = biomarker_flag['status']
|
| 188 |
+
|
| 189 |
# Simple heuristic based on severity
|
| 190 |
if 'CRITICAL' in status:
|
| 191 |
base = 40
|
|
|
|
| 193 |
base = 25
|
| 194 |
else:
|
| 195 |
base = 10
|
| 196 |
+
|
| 197 |
# Adjust based on evidence strength
|
| 198 |
evidence_boost = min(doc_count * 2, 15)
|
| 199 |
+
|
| 200 |
total = min(base + evidence_boost, 60)
|
| 201 |
return f"{total}%"
|
| 202 |
+
|
| 203 |
def _generate_explanation(
|
| 204 |
self,
|
| 205 |
biomarker: str,
|
|
|
|
| 210 |
evidence: str
|
| 211 |
) -> str:
|
| 212 |
"""Generate patient-friendly explanation"""
|
| 213 |
+
|
| 214 |
prompt = f"""Explain in 1-2 sentences how this biomarker result relates to {disease}:
|
| 215 |
|
| 216 |
Biomarker: {biomarker}
|
|
|
|
| 220 |
Medical Evidence: {evidence}
|
| 221 |
|
| 222 |
Write in patient-friendly language, explaining what this means for the diagnosis."""
|
| 223 |
+
|
| 224 |
try:
|
| 225 |
response = self.llm.invoke(prompt)
|
| 226 |
return response.content.strip()
|
| 227 |
+
except Exception:
|
| 228 |
return f"{biomarker} at {value} {unit} is {status}, which may be associated with {disease}."
|
| 229 |
|
| 230 |
|
src/agents/clinical_guidelines.py
CHANGED
|
@@ -4,15 +4,16 @@ Clinical Guidelines Agent - Retrieves evidence-based recommendations
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from pathlib import Path
|
| 7 |
-
|
| 8 |
-
from src.state import GuildState, AgentOutput
|
| 9 |
-
from src.llm_config import llm_config
|
| 10 |
from langchain_core.prompts import ChatPromptTemplate
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class ClinicalGuidelinesAgent:
|
| 14 |
"""Agent that retrieves clinical guidelines and recommendations using RAG"""
|
| 15 |
-
|
| 16 |
def __init__(self, retriever):
|
| 17 |
"""
|
| 18 |
Initialize with a retriever for clinical guidelines.
|
|
@@ -22,7 +23,7 @@ class ClinicalGuidelinesAgent:
|
|
| 22 |
"""
|
| 23 |
self.retriever = retriever
|
| 24 |
self.llm = llm_config.explainer
|
| 25 |
-
|
| 26 |
def recommend(self, state: GuildState) -> GuildState:
|
| 27 |
"""
|
| 28 |
Retrieve clinical guidelines and generate recommendations.
|
|
@@ -36,25 +37,25 @@ class ClinicalGuidelinesAgent:
|
|
| 36 |
print("\n" + "="*70)
|
| 37 |
print("EXECUTING: Clinical Guidelines Agent (RAG)")
|
| 38 |
print("="*70)
|
| 39 |
-
|
| 40 |
model_prediction = state['model_prediction']
|
| 41 |
disease = model_prediction['disease']
|
| 42 |
confidence = model_prediction['confidence']
|
| 43 |
-
|
| 44 |
# Get biomarker analysis
|
| 45 |
biomarker_analysis = state.get('biomarker_analysis') or {}
|
| 46 |
safety_alerts = biomarker_analysis.get('safety_alerts', [])
|
| 47 |
-
|
| 48 |
# Retrieve guidelines
|
| 49 |
print(f"\nRetrieving clinical guidelines for {disease}...")
|
| 50 |
-
|
| 51 |
query = f"""What are the clinical practice guidelines for managing {disease}?
|
| 52 |
Include lifestyle modifications, monitoring recommendations, and when to seek medical care."""
|
| 53 |
-
|
| 54 |
docs = self.retriever.invoke(query)
|
| 55 |
-
|
| 56 |
print(f"Retrieved {len(docs)} guideline documents")
|
| 57 |
-
|
| 58 |
# Generate recommendations
|
| 59 |
if state['sop'].require_pdf_citations and not docs:
|
| 60 |
recommendations = {
|
|
@@ -73,7 +74,7 @@ class ClinicalGuidelinesAgent:
|
|
| 73 |
confidence,
|
| 74 |
state
|
| 75 |
)
|
| 76 |
-
|
| 77 |
# Create agent output
|
| 78 |
output = AgentOutput(
|
| 79 |
agent_name="Clinical Guidelines",
|
|
@@ -87,15 +88,15 @@ class ClinicalGuidelinesAgent:
|
|
| 87 |
"citations_missing": state['sop'].require_pdf_citations and not docs
|
| 88 |
}
|
| 89 |
)
|
| 90 |
-
|
| 91 |
# Update state
|
| 92 |
print("\nRecommendations generated")
|
| 93 |
print(f" - Immediate actions: {len(recommendations['immediate_actions'])}")
|
| 94 |
print(f" - Lifestyle changes: {len(recommendations['lifestyle_changes'])}")
|
| 95 |
print(f" - Monitoring recommendations: {len(recommendations['monitoring'])}")
|
| 96 |
-
|
| 97 |
return {'agent_outputs': [output]}
|
| 98 |
-
|
| 99 |
def _generate_recommendations(
|
| 100 |
self,
|
| 101 |
disease: str,
|
|
@@ -105,20 +106,20 @@ class ClinicalGuidelinesAgent:
|
|
| 105 |
state: GuildState
|
| 106 |
) -> dict:
|
| 107 |
"""Generate structured recommendations using LLM and guidelines"""
|
| 108 |
-
|
| 109 |
# Format retrieved guidelines
|
| 110 |
guidelines_context = "\n\n---\n\n".join([
|
| 111 |
f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
|
| 112 |
for doc in docs
|
| 113 |
])
|
| 114 |
-
|
| 115 |
# Build safety context
|
| 116 |
safety_context = ""
|
| 117 |
if safety_alerts:
|
| 118 |
safety_context = "\n**CRITICAL SAFETY ALERTS:**\n"
|
| 119 |
for alert in safety_alerts[:3]:
|
| 120 |
safety_context += f"- {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}\n"
|
| 121 |
-
|
| 122 |
prompt = ChatPromptTemplate.from_messages([
|
| 123 |
("system", """You are a clinical decision support system providing evidence-based recommendations.
|
| 124 |
Based on clinical practice guidelines, provide actionable recommendations for patient self-assessment.
|
|
@@ -139,9 +140,9 @@ class ClinicalGuidelinesAgent:
|
|
| 139 |
|
| 140 |
Please provide structured recommendations for patient self-assessment.""")
|
| 141 |
])
|
| 142 |
-
|
| 143 |
chain = prompt | self.llm
|
| 144 |
-
|
| 145 |
try:
|
| 146 |
response = chain.invoke({
|
| 147 |
"disease": disease,
|
|
@@ -149,18 +150,18 @@ class ClinicalGuidelinesAgent:
|
|
| 149 |
"safety_context": safety_context,
|
| 150 |
"guidelines": guidelines_context
|
| 151 |
})
|
| 152 |
-
|
| 153 |
recommendations = self._parse_recommendations(response.content)
|
| 154 |
-
|
| 155 |
except Exception as e:
|
| 156 |
print(f"Warning: LLM recommendation generation failed: {e}")
|
| 157 |
recommendations = self._get_default_recommendations(disease, safety_alerts)
|
| 158 |
-
|
| 159 |
# Add citations
|
| 160 |
recommendations['citations'] = self._extract_citations(docs)
|
| 161 |
-
|
| 162 |
return recommendations
|
| 163 |
-
|
| 164 |
def _parse_recommendations(self, content: str) -> dict:
|
| 165 |
"""Parse LLM response into structured recommendations"""
|
| 166 |
recommendations = {
|
|
@@ -168,14 +169,14 @@ class ClinicalGuidelinesAgent:
|
|
| 168 |
"lifestyle_changes": [],
|
| 169 |
"monitoring": []
|
| 170 |
}
|
| 171 |
-
|
| 172 |
current_section = None
|
| 173 |
lines = content.split('\n')
|
| 174 |
-
|
| 175 |
for line in lines:
|
| 176 |
line_stripped = line.strip()
|
| 177 |
line_upper = line_stripped.upper()
|
| 178 |
-
|
| 179 |
# Detect section headers
|
| 180 |
if 'IMMEDIATE' in line_upper or 'URGENT' in line_upper:
|
| 181 |
current_section = 'immediate_actions'
|
|
@@ -189,16 +190,16 @@ class ClinicalGuidelinesAgent:
|
|
| 189 |
cleaned = line_stripped.lstrip('•-*0123456789. ')
|
| 190 |
if cleaned and len(cleaned) > 10: # Minimum length filter
|
| 191 |
recommendations[current_section].append(cleaned)
|
| 192 |
-
|
| 193 |
# If parsing failed, create default structure
|
| 194 |
if not any(recommendations.values()):
|
| 195 |
sentences = content.split('.')
|
| 196 |
recommendations['immediate_actions'] = [s.strip() for s in sentences[:2] if s.strip()]
|
| 197 |
recommendations['lifestyle_changes'] = [s.strip() for s in sentences[2:4] if s.strip()]
|
| 198 |
recommendations['monitoring'] = [s.strip() for s in sentences[4:6] if s.strip()]
|
| 199 |
-
|
| 200 |
return recommendations
|
| 201 |
-
|
| 202 |
def _get_default_recommendations(self, disease: str, safety_alerts: list) -> dict:
|
| 203 |
"""Provide default recommendations if LLM fails"""
|
| 204 |
recommendations = {
|
|
@@ -206,7 +207,7 @@ class ClinicalGuidelinesAgent:
|
|
| 206 |
"lifestyle_changes": [],
|
| 207 |
"monitoring": []
|
| 208 |
}
|
| 209 |
-
|
| 210 |
# Add safety-based immediate actions
|
| 211 |
if safety_alerts:
|
| 212 |
recommendations['immediate_actions'].append(
|
|
@@ -219,36 +220,36 @@ class ClinicalGuidelinesAgent:
|
|
| 219 |
recommendations['immediate_actions'].append(
|
| 220 |
f"Schedule appointment with healthcare provider to discuss {disease} findings"
|
| 221 |
)
|
| 222 |
-
|
| 223 |
# Generic lifestyle changes
|
| 224 |
recommendations['lifestyle_changes'].extend([
|
| 225 |
"Follow a balanced, nutrient-rich diet as recommended by healthcare provider",
|
| 226 |
"Maintain regular physical activity appropriate for your health status",
|
| 227 |
"Track symptoms and biomarker trends over time"
|
| 228 |
])
|
| 229 |
-
|
| 230 |
# Generic monitoring
|
| 231 |
recommendations['monitoring'].extend([
|
| 232 |
f"Regular monitoring of {disease}-related biomarkers as advised by physician",
|
| 233 |
"Keep a health journal tracking symptoms, diet, and activities",
|
| 234 |
"Schedule follow-up appointments as recommended"
|
| 235 |
])
|
| 236 |
-
|
| 237 |
return recommendations
|
| 238 |
-
|
| 239 |
-
def _extract_citations(self, docs: list) ->
|
| 240 |
"""Extract citations from retrieved guideline documents"""
|
| 241 |
citations = []
|
| 242 |
-
|
| 243 |
for doc in docs:
|
| 244 |
source = doc.metadata.get('source', 'Unknown')
|
| 245 |
-
|
| 246 |
# Clean up source path
|
| 247 |
if '\\' in source or '/' in source:
|
| 248 |
source = Path(source).name
|
| 249 |
-
|
| 250 |
citations.append(source)
|
| 251 |
-
|
| 252 |
return list(set(citations)) # Remove duplicates
|
| 253 |
|
| 254 |
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from pathlib import Path
|
| 7 |
+
|
|
|
|
|
|
|
| 8 |
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
|
| 10 |
+
from src.llm_config import llm_config
|
| 11 |
+
from src.state import AgentOutput, GuildState
|
| 12 |
+
|
| 13 |
|
| 14 |
class ClinicalGuidelinesAgent:
|
| 15 |
"""Agent that retrieves clinical guidelines and recommendations using RAG"""
|
| 16 |
+
|
| 17 |
def __init__(self, retriever):
|
| 18 |
"""
|
| 19 |
Initialize with a retriever for clinical guidelines.
|
|
|
|
| 23 |
"""
|
| 24 |
self.retriever = retriever
|
| 25 |
self.llm = llm_config.explainer
|
| 26 |
+
|
| 27 |
def recommend(self, state: GuildState) -> GuildState:
|
| 28 |
"""
|
| 29 |
Retrieve clinical guidelines and generate recommendations.
|
|
|
|
| 37 |
print("\n" + "="*70)
|
| 38 |
print("EXECUTING: Clinical Guidelines Agent (RAG)")
|
| 39 |
print("="*70)
|
| 40 |
+
|
| 41 |
model_prediction = state['model_prediction']
|
| 42 |
disease = model_prediction['disease']
|
| 43 |
confidence = model_prediction['confidence']
|
| 44 |
+
|
| 45 |
# Get biomarker analysis
|
| 46 |
biomarker_analysis = state.get('biomarker_analysis') or {}
|
| 47 |
safety_alerts = biomarker_analysis.get('safety_alerts', [])
|
| 48 |
+
|
| 49 |
# Retrieve guidelines
|
| 50 |
print(f"\nRetrieving clinical guidelines for {disease}...")
|
| 51 |
+
|
| 52 |
query = f"""What are the clinical practice guidelines for managing {disease}?
|
| 53 |
Include lifestyle modifications, monitoring recommendations, and when to seek medical care."""
|
| 54 |
+
|
| 55 |
docs = self.retriever.invoke(query)
|
| 56 |
+
|
| 57 |
print(f"Retrieved {len(docs)} guideline documents")
|
| 58 |
+
|
| 59 |
# Generate recommendations
|
| 60 |
if state['sop'].require_pdf_citations and not docs:
|
| 61 |
recommendations = {
|
|
|
|
| 74 |
confidence,
|
| 75 |
state
|
| 76 |
)
|
| 77 |
+
|
| 78 |
# Create agent output
|
| 79 |
output = AgentOutput(
|
| 80 |
agent_name="Clinical Guidelines",
|
|
|
|
| 88 |
"citations_missing": state['sop'].require_pdf_citations and not docs
|
| 89 |
}
|
| 90 |
)
|
| 91 |
+
|
| 92 |
# Update state
|
| 93 |
print("\nRecommendations generated")
|
| 94 |
print(f" - Immediate actions: {len(recommendations['immediate_actions'])}")
|
| 95 |
print(f" - Lifestyle changes: {len(recommendations['lifestyle_changes'])}")
|
| 96 |
print(f" - Monitoring recommendations: {len(recommendations['monitoring'])}")
|
| 97 |
+
|
| 98 |
return {'agent_outputs': [output]}
|
| 99 |
+
|
| 100 |
def _generate_recommendations(
|
| 101 |
self,
|
| 102 |
disease: str,
|
|
|
|
| 106 |
state: GuildState
|
| 107 |
) -> dict:
|
| 108 |
"""Generate structured recommendations using LLM and guidelines"""
|
| 109 |
+
|
| 110 |
# Format retrieved guidelines
|
| 111 |
guidelines_context = "\n\n---\n\n".join([
|
| 112 |
f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
|
| 113 |
for doc in docs
|
| 114 |
])
|
| 115 |
+
|
| 116 |
# Build safety context
|
| 117 |
safety_context = ""
|
| 118 |
if safety_alerts:
|
| 119 |
safety_context = "\n**CRITICAL SAFETY ALERTS:**\n"
|
| 120 |
for alert in safety_alerts[:3]:
|
| 121 |
safety_context += f"- {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}\n"
|
| 122 |
+
|
| 123 |
prompt = ChatPromptTemplate.from_messages([
|
| 124 |
("system", """You are a clinical decision support system providing evidence-based recommendations.
|
| 125 |
Based on clinical practice guidelines, provide actionable recommendations for patient self-assessment.
|
|
|
|
| 140 |
|
| 141 |
Please provide structured recommendations for patient self-assessment.""")
|
| 142 |
])
|
| 143 |
+
|
| 144 |
chain = prompt | self.llm
|
| 145 |
+
|
| 146 |
try:
|
| 147 |
response = chain.invoke({
|
| 148 |
"disease": disease,
|
|
|
|
| 150 |
"safety_context": safety_context,
|
| 151 |
"guidelines": guidelines_context
|
| 152 |
})
|
| 153 |
+
|
| 154 |
recommendations = self._parse_recommendations(response.content)
|
| 155 |
+
|
| 156 |
except Exception as e:
|
| 157 |
print(f"Warning: LLM recommendation generation failed: {e}")
|
| 158 |
recommendations = self._get_default_recommendations(disease, safety_alerts)
|
| 159 |
+
|
| 160 |
# Add citations
|
| 161 |
recommendations['citations'] = self._extract_citations(docs)
|
| 162 |
+
|
| 163 |
return recommendations
|
| 164 |
+
|
| 165 |
def _parse_recommendations(self, content: str) -> dict:
|
| 166 |
"""Parse LLM response into structured recommendations"""
|
| 167 |
recommendations = {
|
|
|
|
| 169 |
"lifestyle_changes": [],
|
| 170 |
"monitoring": []
|
| 171 |
}
|
| 172 |
+
|
| 173 |
current_section = None
|
| 174 |
lines = content.split('\n')
|
| 175 |
+
|
| 176 |
for line in lines:
|
| 177 |
line_stripped = line.strip()
|
| 178 |
line_upper = line_stripped.upper()
|
| 179 |
+
|
| 180 |
# Detect section headers
|
| 181 |
if 'IMMEDIATE' in line_upper or 'URGENT' in line_upper:
|
| 182 |
current_section = 'immediate_actions'
|
|
|
|
| 190 |
cleaned = line_stripped.lstrip('•-*0123456789. ')
|
| 191 |
if cleaned and len(cleaned) > 10: # Minimum length filter
|
| 192 |
recommendations[current_section].append(cleaned)
|
| 193 |
+
|
| 194 |
# If parsing failed, create default structure
|
| 195 |
if not any(recommendations.values()):
|
| 196 |
sentences = content.split('.')
|
| 197 |
recommendations['immediate_actions'] = [s.strip() for s in sentences[:2] if s.strip()]
|
| 198 |
recommendations['lifestyle_changes'] = [s.strip() for s in sentences[2:4] if s.strip()]
|
| 199 |
recommendations['monitoring'] = [s.strip() for s in sentences[4:6] if s.strip()]
|
| 200 |
+
|
| 201 |
return recommendations
|
| 202 |
+
|
| 203 |
def _get_default_recommendations(self, disease: str, safety_alerts: list) -> dict:
|
| 204 |
"""Provide default recommendations if LLM fails"""
|
| 205 |
recommendations = {
|
|
|
|
| 207 |
"lifestyle_changes": [],
|
| 208 |
"monitoring": []
|
| 209 |
}
|
| 210 |
+
|
| 211 |
# Add safety-based immediate actions
|
| 212 |
if safety_alerts:
|
| 213 |
recommendations['immediate_actions'].append(
|
|
|
|
| 220 |
recommendations['immediate_actions'].append(
|
| 221 |
f"Schedule appointment with healthcare provider to discuss {disease} findings"
|
| 222 |
)
|
| 223 |
+
|
| 224 |
# Generic lifestyle changes
|
| 225 |
recommendations['lifestyle_changes'].extend([
|
| 226 |
"Follow a balanced, nutrient-rich diet as recommended by healthcare provider",
|
| 227 |
"Maintain regular physical activity appropriate for your health status",
|
| 228 |
"Track symptoms and biomarker trends over time"
|
| 229 |
])
|
| 230 |
+
|
| 231 |
# Generic monitoring
|
| 232 |
recommendations['monitoring'].extend([
|
| 233 |
f"Regular monitoring of {disease}-related biomarkers as advised by physician",
|
| 234 |
"Keep a health journal tracking symptoms, diet, and activities",
|
| 235 |
"Schedule follow-up appointments as recommended"
|
| 236 |
])
|
| 237 |
+
|
| 238 |
return recommendations
|
| 239 |
+
|
| 240 |
+
def _extract_citations(self, docs: list) -> list[str]:
|
| 241 |
"""Extract citations from retrieved guideline documents"""
|
| 242 |
citations = []
|
| 243 |
+
|
| 244 |
for doc in docs:
|
| 245 |
source = doc.metadata.get('source', 'Unknown')
|
| 246 |
+
|
| 247 |
# Clean up source path
|
| 248 |
if '\\' in source or '/' in source:
|
| 249 |
source = Path(source).name
|
| 250 |
+
|
| 251 |
citations.append(source)
|
| 252 |
+
|
| 253 |
return list(set(citations)) # Remove duplicates
|
| 254 |
|
| 255 |
|
src/agents/confidence_assessor.py
CHANGED
|
@@ -3,19 +3,19 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Confidence Assessor Agent - Evaluates prediction reliability
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
from typing import Any
|
| 7 |
-
|
| 8 |
from src.biomarker_validator import BiomarkerValidator
|
| 9 |
from src.llm_config import llm_config
|
| 10 |
-
from
|
| 11 |
|
| 12 |
|
| 13 |
class ConfidenceAssessorAgent:
|
| 14 |
"""Agent that assesses the reliability and limitations of the prediction"""
|
| 15 |
-
|
| 16 |
def __init__(self):
|
| 17 |
self.llm = llm_config.analyzer
|
| 18 |
-
|
| 19 |
def assess(self, state: GuildState) -> GuildState:
|
| 20 |
"""
|
| 21 |
Assess prediction confidence and identify limitations.
|
|
@@ -29,41 +29,41 @@ class ConfidenceAssessorAgent:
|
|
| 29 |
print("\n" + "="*70)
|
| 30 |
print("EXECUTING: Confidence Assessor Agent")
|
| 31 |
print("="*70)
|
| 32 |
-
|
| 33 |
model_prediction = state['model_prediction']
|
| 34 |
disease = model_prediction['disease']
|
| 35 |
ml_confidence = model_prediction['confidence']
|
| 36 |
probabilities = model_prediction.get('probabilities', {})
|
| 37 |
biomarkers = state['patient_biomarkers']
|
| 38 |
-
|
| 39 |
# Collect previous agent findings
|
| 40 |
biomarker_analysis = state.get('biomarker_analysis') or {}
|
| 41 |
disease_explanation = self._get_agent_findings(state, "Disease Explainer")
|
| 42 |
linker_findings = self._get_agent_findings(state, "Biomarker-Disease Linker")
|
| 43 |
-
|
| 44 |
print(f"\nAssessing confidence for {disease} prediction...")
|
| 45 |
-
|
| 46 |
# Evaluate evidence strength
|
| 47 |
evidence_strength = self._evaluate_evidence_strength(
|
| 48 |
biomarker_analysis,
|
| 49 |
disease_explanation,
|
| 50 |
linker_findings
|
| 51 |
)
|
| 52 |
-
|
| 53 |
# Identify limitations
|
| 54 |
limitations = self._identify_limitations(
|
| 55 |
biomarkers,
|
| 56 |
biomarker_analysis,
|
| 57 |
probabilities
|
| 58 |
)
|
| 59 |
-
|
| 60 |
# Calculate aggregate reliability
|
| 61 |
reliability = self._calculate_reliability(
|
| 62 |
ml_confidence,
|
| 63 |
evidence_strength,
|
| 64 |
len(limitations)
|
| 65 |
)
|
| 66 |
-
|
| 67 |
# Generate assessment summary
|
| 68 |
assessment_summary = self._generate_assessment(
|
| 69 |
disease,
|
|
@@ -72,7 +72,7 @@ class ConfidenceAssessorAgent:
|
|
| 72 |
evidence_strength,
|
| 73 |
limitations
|
| 74 |
)
|
| 75 |
-
|
| 76 |
# Create agent output
|
| 77 |
output = AgentOutput(
|
| 78 |
agent_name="Confidence Assessor",
|
|
@@ -86,22 +86,22 @@ class ConfidenceAssessorAgent:
|
|
| 86 |
"alternative_diagnoses": self._get_alternatives(probabilities)
|
| 87 |
}
|
| 88 |
)
|
| 89 |
-
|
| 90 |
# Update state
|
| 91 |
print("\nConfidence assessment complete")
|
| 92 |
print(f" - Prediction reliability: {reliability}")
|
| 93 |
print(f" - Evidence strength: {evidence_strength}")
|
| 94 |
print(f" - Limitations identified: {len(limitations)}")
|
| 95 |
-
|
| 96 |
return {'agent_outputs': [output]}
|
| 97 |
-
|
| 98 |
def _get_agent_findings(self, state: GuildState, agent_name: str) -> dict:
|
| 99 |
"""Extract findings from a specific agent"""
|
| 100 |
for output in state.get('agent_outputs', []):
|
| 101 |
if output.agent_name == agent_name:
|
| 102 |
return output.findings
|
| 103 |
return {}
|
| 104 |
-
|
| 105 |
def _evaluate_evidence_strength(
|
| 106 |
self,
|
| 107 |
biomarker_analysis: dict,
|
|
@@ -109,10 +109,10 @@ class ConfidenceAssessorAgent:
|
|
| 109 |
linker_findings: dict
|
| 110 |
) -> str:
|
| 111 |
"""Evaluate the strength of supporting evidence"""
|
| 112 |
-
|
| 113 |
score = 0
|
| 114 |
max_score = 5
|
| 115 |
-
|
| 116 |
# Check biomarker validation quality
|
| 117 |
flags = biomarker_analysis.get('biomarker_flags', [])
|
| 118 |
abnormal_count = len([f for f in flags if f.get('status') != 'NORMAL'])
|
|
@@ -120,18 +120,18 @@ class ConfidenceAssessorAgent:
|
|
| 120 |
score += 1
|
| 121 |
if abnormal_count >= 5:
|
| 122 |
score += 1
|
| 123 |
-
|
| 124 |
# Check disease explanation quality
|
| 125 |
if disease_explanation.get('retrieval_quality', 0) >= 3:
|
| 126 |
score += 1
|
| 127 |
-
|
| 128 |
# Check biomarker-disease linking
|
| 129 |
key_drivers = linker_findings.get('key_drivers', [])
|
| 130 |
if len(key_drivers) >= 2:
|
| 131 |
score += 1
|
| 132 |
if len(key_drivers) >= 4:
|
| 133 |
score += 1
|
| 134 |
-
|
| 135 |
# Map score to categorical rating
|
| 136 |
if score >= 4:
|
| 137 |
return "STRONG"
|
|
@@ -139,22 +139,22 @@ class ConfidenceAssessorAgent:
|
|
| 139 |
return "MODERATE"
|
| 140 |
else:
|
| 141 |
return "WEAK"
|
| 142 |
-
|
| 143 |
def _identify_limitations(
|
| 144 |
self,
|
| 145 |
-
biomarkers:
|
| 146 |
biomarker_analysis: dict,
|
| 147 |
-
probabilities:
|
| 148 |
-
) ->
|
| 149 |
"""Identify limitations and uncertainties"""
|
| 150 |
limitations = []
|
| 151 |
-
|
| 152 |
# Check for missing biomarkers
|
| 153 |
expected_biomarkers = BiomarkerValidator().expected_biomarker_count()
|
| 154 |
if len(biomarkers) < expected_biomarkers:
|
| 155 |
missing = expected_biomarkers - len(biomarkers)
|
| 156 |
limitations.append(f"Missing data: {missing} biomarker(s) not provided")
|
| 157 |
-
|
| 158 |
# Check for close alternative predictions
|
| 159 |
sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
|
| 160 |
if len(sorted_probs) >= 2:
|
|
@@ -164,7 +164,7 @@ class ConfidenceAssessorAgent:
|
|
| 164 |
limitations.append(
|
| 165 |
f"Differential diagnosis: {top2} also possible ({prob2:.1%} probability)"
|
| 166 |
)
|
| 167 |
-
|
| 168 |
# Check for normal biomarkers despite prediction
|
| 169 |
flags = biomarker_analysis.get('biomarker_flags', [])
|
| 170 |
relevant = biomarker_analysis.get('relevant_biomarkers', [])
|
|
@@ -174,18 +174,18 @@ class ConfidenceAssessorAgent:
|
|
| 174 |
]
|
| 175 |
if len(normal_relevant) >= 2:
|
| 176 |
limitations.append(
|
| 177 |
-
|
| 178 |
)
|
| 179 |
-
|
| 180 |
# Check for safety alerts (indicates complexity)
|
| 181 |
alerts = biomarker_analysis.get('safety_alerts', [])
|
| 182 |
if len(alerts) >= 2:
|
| 183 |
limitations.append(
|
| 184 |
"Multiple critical values detected; professional evaluation essential"
|
| 185 |
)
|
| 186 |
-
|
| 187 |
return limitations
|
| 188 |
-
|
| 189 |
def _calculate_reliability(
|
| 190 |
self,
|
| 191 |
ml_confidence: float,
|
|
@@ -193,9 +193,9 @@ class ConfidenceAssessorAgent:
|
|
| 193 |
limitation_count: int
|
| 194 |
) -> str:
|
| 195 |
"""Calculate overall prediction reliability"""
|
| 196 |
-
|
| 197 |
score = 0
|
| 198 |
-
|
| 199 |
# ML confidence contribution
|
| 200 |
if ml_confidence >= 0.8:
|
| 201 |
score += 3
|
|
@@ -203,7 +203,7 @@ class ConfidenceAssessorAgent:
|
|
| 203 |
score += 2
|
| 204 |
elif ml_confidence >= 0.4:
|
| 205 |
score += 1
|
| 206 |
-
|
| 207 |
# Evidence strength contribution
|
| 208 |
if evidence_strength == "STRONG":
|
| 209 |
score += 3
|
|
@@ -211,10 +211,10 @@ class ConfidenceAssessorAgent:
|
|
| 211 |
score += 2
|
| 212 |
else:
|
| 213 |
score += 1
|
| 214 |
-
|
| 215 |
# Limitation penalty
|
| 216 |
score -= min(limitation_count, 3)
|
| 217 |
-
|
| 218 |
# Map to categorical
|
| 219 |
if score >= 5:
|
| 220 |
return "HIGH"
|
|
@@ -222,17 +222,17 @@ class ConfidenceAssessorAgent:
|
|
| 222 |
return "MODERATE"
|
| 223 |
else:
|
| 224 |
return "LOW"
|
| 225 |
-
|
| 226 |
def _generate_assessment(
|
| 227 |
self,
|
| 228 |
disease: str,
|
| 229 |
ml_confidence: float,
|
| 230 |
reliability: str,
|
| 231 |
evidence_strength: str,
|
| 232 |
-
limitations:
|
| 233 |
) -> str:
|
| 234 |
"""Generate human-readable assessment summary"""
|
| 235 |
-
|
| 236 |
prompt = f"""As a medical AI assessment system, provide a brief confidence statement about this prediction:
|
| 237 |
|
| 238 |
Disease Predicted: {disease}
|
|
@@ -254,7 +254,7 @@ Be honest about uncertainty. Patient safety is paramount."""
|
|
| 254 |
except Exception as e:
|
| 255 |
print(f"Warning: Assessment generation failed: {e}")
|
| 256 |
return f"The {disease} prediction has {reliability.lower()} reliability based on available data. Professional medical evaluation is strongly recommended for accurate diagnosis."
|
| 257 |
-
|
| 258 |
def _get_recommendation(self, reliability: str) -> str:
|
| 259 |
"""Get action recommendation based on reliability"""
|
| 260 |
if reliability == "HIGH":
|
|
@@ -263,11 +263,11 @@ Be honest about uncertainty. Patient safety is paramount."""
|
|
| 263 |
return "Moderate confidence prediction. Medical consultation recommended for professional evaluation and additional testing if needed."
|
| 264 |
else:
|
| 265 |
return "Low confidence prediction. Professional medical assessment essential. Additional tests may be required for accurate diagnosis."
|
| 266 |
-
|
| 267 |
-
def _get_alternatives(self, probabilities:
|
| 268 |
"""Get alternative diagnoses to consider"""
|
| 269 |
sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
|
| 270 |
-
|
| 271 |
alternatives = []
|
| 272 |
for disease, prob in sorted_probs[1:4]: # Top 3 alternatives
|
| 273 |
if prob > 0.05: # Only significant alternatives
|
|
@@ -276,7 +276,7 @@ Be honest about uncertainty. Patient safety is paramount."""
|
|
| 276 |
"probability": prob,
|
| 277 |
"note": "Consider discussing with healthcare provider"
|
| 278 |
})
|
| 279 |
-
|
| 280 |
return alternatives
|
| 281 |
|
| 282 |
|
|
|
|
| 3 |
Confidence Assessor Agent - Evaluates prediction reliability
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
from src.biomarker_validator import BiomarkerValidator
|
| 9 |
from src.llm_config import llm_config
|
| 10 |
+
from src.state import AgentOutput, GuildState
|
| 11 |
|
| 12 |
|
| 13 |
class ConfidenceAssessorAgent:
|
| 14 |
"""Agent that assesses the reliability and limitations of the prediction"""
|
| 15 |
+
|
| 16 |
def __init__(self):
|
| 17 |
self.llm = llm_config.analyzer
|
| 18 |
+
|
| 19 |
def assess(self, state: GuildState) -> GuildState:
|
| 20 |
"""
|
| 21 |
Assess prediction confidence and identify limitations.
|
|
|
|
| 29 |
print("\n" + "="*70)
|
| 30 |
print("EXECUTING: Confidence Assessor Agent")
|
| 31 |
print("="*70)
|
| 32 |
+
|
| 33 |
model_prediction = state['model_prediction']
|
| 34 |
disease = model_prediction['disease']
|
| 35 |
ml_confidence = model_prediction['confidence']
|
| 36 |
probabilities = model_prediction.get('probabilities', {})
|
| 37 |
biomarkers = state['patient_biomarkers']
|
| 38 |
+
|
| 39 |
# Collect previous agent findings
|
| 40 |
biomarker_analysis = state.get('biomarker_analysis') or {}
|
| 41 |
disease_explanation = self._get_agent_findings(state, "Disease Explainer")
|
| 42 |
linker_findings = self._get_agent_findings(state, "Biomarker-Disease Linker")
|
| 43 |
+
|
| 44 |
print(f"\nAssessing confidence for {disease} prediction...")
|
| 45 |
+
|
| 46 |
# Evaluate evidence strength
|
| 47 |
evidence_strength = self._evaluate_evidence_strength(
|
| 48 |
biomarker_analysis,
|
| 49 |
disease_explanation,
|
| 50 |
linker_findings
|
| 51 |
)
|
| 52 |
+
|
| 53 |
# Identify limitations
|
| 54 |
limitations = self._identify_limitations(
|
| 55 |
biomarkers,
|
| 56 |
biomarker_analysis,
|
| 57 |
probabilities
|
| 58 |
)
|
| 59 |
+
|
| 60 |
# Calculate aggregate reliability
|
| 61 |
reliability = self._calculate_reliability(
|
| 62 |
ml_confidence,
|
| 63 |
evidence_strength,
|
| 64 |
len(limitations)
|
| 65 |
)
|
| 66 |
+
|
| 67 |
# Generate assessment summary
|
| 68 |
assessment_summary = self._generate_assessment(
|
| 69 |
disease,
|
|
|
|
| 72 |
evidence_strength,
|
| 73 |
limitations
|
| 74 |
)
|
| 75 |
+
|
| 76 |
# Create agent output
|
| 77 |
output = AgentOutput(
|
| 78 |
agent_name="Confidence Assessor",
|
|
|
|
| 86 |
"alternative_diagnoses": self._get_alternatives(probabilities)
|
| 87 |
}
|
| 88 |
)
|
| 89 |
+
|
| 90 |
# Update state
|
| 91 |
print("\nConfidence assessment complete")
|
| 92 |
print(f" - Prediction reliability: {reliability}")
|
| 93 |
print(f" - Evidence strength: {evidence_strength}")
|
| 94 |
print(f" - Limitations identified: {len(limitations)}")
|
| 95 |
+
|
| 96 |
return {'agent_outputs': [output]}
|
| 97 |
+
|
| 98 |
def _get_agent_findings(self, state: GuildState, agent_name: str) -> dict:
|
| 99 |
"""Extract findings from a specific agent"""
|
| 100 |
for output in state.get('agent_outputs', []):
|
| 101 |
if output.agent_name == agent_name:
|
| 102 |
return output.findings
|
| 103 |
return {}
|
| 104 |
+
|
| 105 |
def _evaluate_evidence_strength(
|
| 106 |
self,
|
| 107 |
biomarker_analysis: dict,
|
|
|
|
| 109 |
linker_findings: dict
|
| 110 |
) -> str:
|
| 111 |
"""Evaluate the strength of supporting evidence"""
|
| 112 |
+
|
| 113 |
score = 0
|
| 114 |
max_score = 5
|
| 115 |
+
|
| 116 |
# Check biomarker validation quality
|
| 117 |
flags = biomarker_analysis.get('biomarker_flags', [])
|
| 118 |
abnormal_count = len([f for f in flags if f.get('status') != 'NORMAL'])
|
|
|
|
| 120 |
score += 1
|
| 121 |
if abnormal_count >= 5:
|
| 122 |
score += 1
|
| 123 |
+
|
| 124 |
# Check disease explanation quality
|
| 125 |
if disease_explanation.get('retrieval_quality', 0) >= 3:
|
| 126 |
score += 1
|
| 127 |
+
|
| 128 |
# Check biomarker-disease linking
|
| 129 |
key_drivers = linker_findings.get('key_drivers', [])
|
| 130 |
if len(key_drivers) >= 2:
|
| 131 |
score += 1
|
| 132 |
if len(key_drivers) >= 4:
|
| 133 |
score += 1
|
| 134 |
+
|
| 135 |
# Map score to categorical rating
|
| 136 |
if score >= 4:
|
| 137 |
return "STRONG"
|
|
|
|
| 139 |
return "MODERATE"
|
| 140 |
else:
|
| 141 |
return "WEAK"
|
| 142 |
+
|
| 143 |
def _identify_limitations(
|
| 144 |
self,
|
| 145 |
+
biomarkers: dict[str, float],
|
| 146 |
biomarker_analysis: dict,
|
| 147 |
+
probabilities: dict[str, float]
|
| 148 |
+
) -> list[str]:
|
| 149 |
"""Identify limitations and uncertainties"""
|
| 150 |
limitations = []
|
| 151 |
+
|
| 152 |
# Check for missing biomarkers
|
| 153 |
expected_biomarkers = BiomarkerValidator().expected_biomarker_count()
|
| 154 |
if len(biomarkers) < expected_biomarkers:
|
| 155 |
missing = expected_biomarkers - len(biomarkers)
|
| 156 |
limitations.append(f"Missing data: {missing} biomarker(s) not provided")
|
| 157 |
+
|
| 158 |
# Check for close alternative predictions
|
| 159 |
sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
|
| 160 |
if len(sorted_probs) >= 2:
|
|
|
|
| 164 |
limitations.append(
|
| 165 |
f"Differential diagnosis: {top2} also possible ({prob2:.1%} probability)"
|
| 166 |
)
|
| 167 |
+
|
| 168 |
# Check for normal biomarkers despite prediction
|
| 169 |
flags = biomarker_analysis.get('biomarker_flags', [])
|
| 170 |
relevant = biomarker_analysis.get('relevant_biomarkers', [])
|
|
|
|
| 174 |
]
|
| 175 |
if len(normal_relevant) >= 2:
|
| 176 |
limitations.append(
|
| 177 |
+
"Some disease-relevant biomarkers are within normal range"
|
| 178 |
)
|
| 179 |
+
|
| 180 |
# Check for safety alerts (indicates complexity)
|
| 181 |
alerts = biomarker_analysis.get('safety_alerts', [])
|
| 182 |
if len(alerts) >= 2:
|
| 183 |
limitations.append(
|
| 184 |
"Multiple critical values detected; professional evaluation essential"
|
| 185 |
)
|
| 186 |
+
|
| 187 |
return limitations
|
| 188 |
+
|
| 189 |
def _calculate_reliability(
|
| 190 |
self,
|
| 191 |
ml_confidence: float,
|
|
|
|
| 193 |
limitation_count: int
|
| 194 |
) -> str:
|
| 195 |
"""Calculate overall prediction reliability"""
|
| 196 |
+
|
| 197 |
score = 0
|
| 198 |
+
|
| 199 |
# ML confidence contribution
|
| 200 |
if ml_confidence >= 0.8:
|
| 201 |
score += 3
|
|
|
|
| 203 |
score += 2
|
| 204 |
elif ml_confidence >= 0.4:
|
| 205 |
score += 1
|
| 206 |
+
|
| 207 |
# Evidence strength contribution
|
| 208 |
if evidence_strength == "STRONG":
|
| 209 |
score += 3
|
|
|
|
| 211 |
score += 2
|
| 212 |
else:
|
| 213 |
score += 1
|
| 214 |
+
|
| 215 |
# Limitation penalty
|
| 216 |
score -= min(limitation_count, 3)
|
| 217 |
+
|
| 218 |
# Map to categorical
|
| 219 |
if score >= 5:
|
| 220 |
return "HIGH"
|
|
|
|
| 222 |
return "MODERATE"
|
| 223 |
else:
|
| 224 |
return "LOW"
|
| 225 |
+
|
| 226 |
def _generate_assessment(
|
| 227 |
self,
|
| 228 |
disease: str,
|
| 229 |
ml_confidence: float,
|
| 230 |
reliability: str,
|
| 231 |
evidence_strength: str,
|
| 232 |
+
limitations: list[str]
|
| 233 |
) -> str:
|
| 234 |
"""Generate human-readable assessment summary"""
|
| 235 |
+
|
| 236 |
prompt = f"""As a medical AI assessment system, provide a brief confidence statement about this prediction:
|
| 237 |
|
| 238 |
Disease Predicted: {disease}
|
|
|
|
| 254 |
except Exception as e:
|
| 255 |
print(f"Warning: Assessment generation failed: {e}")
|
| 256 |
return f"The {disease} prediction has {reliability.lower()} reliability based on available data. Professional medical evaluation is strongly recommended for accurate diagnosis."
|
| 257 |
+
|
| 258 |
def _get_recommendation(self, reliability: str) -> str:
|
| 259 |
"""Get action recommendation based on reliability"""
|
| 260 |
if reliability == "HIGH":
|
|
|
|
| 263 |
return "Moderate confidence prediction. Medical consultation recommended for professional evaluation and additional testing if needed."
|
| 264 |
else:
|
| 265 |
return "Low confidence prediction. Professional medical assessment essential. Additional tests may be required for accurate diagnosis."
|
| 266 |
+
|
| 267 |
+
def _get_alternatives(self, probabilities: dict[str, float]) -> list[dict[str, Any]]:
|
| 268 |
"""Get alternative diagnoses to consider"""
|
| 269 |
sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
|
| 270 |
+
|
| 271 |
alternatives = []
|
| 272 |
for disease, prob in sorted_probs[1:4]: # Top 3 alternatives
|
| 273 |
if prob > 0.05: # Only significant alternatives
|
|
|
|
| 276 |
"probability": prob,
|
| 277 |
"note": "Consider discussing with healthcare provider"
|
| 278 |
})
|
| 279 |
+
|
| 280 |
return alternatives
|
| 281 |
|
| 282 |
|
src/agents/disease_explainer.py
CHANGED
|
@@ -4,14 +4,16 @@ Disease Explainer Agent - Retrieves disease pathophysiology from medical PDFs
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from pathlib import Path
|
| 7 |
-
|
| 8 |
-
from src.llm_config import llm_config
|
| 9 |
from langchain_core.prompts import ChatPromptTemplate
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
class DiseaseExplainerAgent:
|
| 13 |
"""Agent that retrieves and explains disease mechanisms using RAG"""
|
| 14 |
-
|
| 15 |
def __init__(self, retriever):
|
| 16 |
"""
|
| 17 |
Initialize with a retriever for medical PDFs.
|
|
@@ -21,7 +23,7 @@ class DiseaseExplainerAgent:
|
|
| 21 |
"""
|
| 22 |
self.retriever = retriever
|
| 23 |
self.llm = llm_config.explainer
|
| 24 |
-
|
| 25 |
def explain(self, state: GuildState) -> GuildState:
|
| 26 |
"""
|
| 27 |
Retrieve and explain disease pathophysiology.
|
|
@@ -35,23 +37,23 @@ class DiseaseExplainerAgent:
|
|
| 35 |
print("\n" + "="*70)
|
| 36 |
print("EXECUTING: Disease Explainer Agent (RAG)")
|
| 37 |
print("="*70)
|
| 38 |
-
|
| 39 |
model_prediction = state['model_prediction']
|
| 40 |
disease = model_prediction['disease']
|
| 41 |
confidence = model_prediction['confidence']
|
| 42 |
-
|
| 43 |
# Configure retrieval based on SOP — create a copy to avoid mutating shared retriever
|
| 44 |
retrieval_k = state['sop'].disease_explainer_k
|
| 45 |
original_search_kwargs = dict(self.retriever.search_kwargs)
|
| 46 |
self.retriever.search_kwargs = {**original_search_kwargs, 'k': retrieval_k}
|
| 47 |
-
|
| 48 |
# Retrieve relevant documents
|
| 49 |
print(f"\nRetrieving information about: {disease}")
|
| 50 |
print(f"Retrieval k={state['sop'].disease_explainer_k}")
|
| 51 |
-
|
| 52 |
query = f"""What is {disease}? Explain the pathophysiology, diagnostic criteria,
|
| 53 |
and clinical presentation. Focus on mechanisms relevant to blood biomarkers."""
|
| 54 |
-
|
| 55 |
try:
|
| 56 |
docs = self.retriever.invoke(query)
|
| 57 |
finally:
|
|
@@ -87,13 +89,13 @@ class DiseaseExplainerAgent:
|
|
| 87 |
print(" - Pathophysiology: insufficient evidence")
|
| 88 |
print(" - Citations: 0 sources")
|
| 89 |
return {'agent_outputs': [output]}
|
| 90 |
-
|
| 91 |
# Generate explanation
|
| 92 |
explanation = self._generate_explanation(disease, docs, confidence)
|
| 93 |
-
|
| 94 |
# Extract citations
|
| 95 |
citations = self._extract_citations(docs)
|
| 96 |
-
|
| 97 |
# Create agent output
|
| 98 |
output = AgentOutput(
|
| 99 |
agent_name="Disease Explainer",
|
|
@@ -109,23 +111,23 @@ class DiseaseExplainerAgent:
|
|
| 109 |
"citations_missing": False
|
| 110 |
}
|
| 111 |
)
|
| 112 |
-
|
| 113 |
# Update state
|
| 114 |
print("\nDisease explanation generated")
|
| 115 |
print(f" - Pathophysiology: {len(explanation['pathophysiology'])} chars")
|
| 116 |
print(f" - Citations: {len(citations)} sources")
|
| 117 |
-
|
| 118 |
return {'agent_outputs': [output]}
|
| 119 |
-
|
| 120 |
def _generate_explanation(self, disease: str, docs: list, confidence: float) -> dict:
|
| 121 |
"""Generate structured disease explanation using LLM and retrieved docs"""
|
| 122 |
-
|
| 123 |
# Format retrieved context
|
| 124 |
context = "\n\n---\n\n".join([
|
| 125 |
f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
|
| 126 |
for doc in docs
|
| 127 |
])
|
| 128 |
-
|
| 129 |
prompt = ChatPromptTemplate.from_messages([
|
| 130 |
("system", """You are a medical expert explaining diseases for patient self-assessment.
|
| 131 |
Based on the provided medical literature, explain the disease in clear, accessible language.
|
|
@@ -144,20 +146,20 @@ class DiseaseExplainerAgent:
|
|
| 144 |
|
| 145 |
Please provide a structured explanation.""")
|
| 146 |
])
|
| 147 |
-
|
| 148 |
chain = prompt | self.llm
|
| 149 |
-
|
| 150 |
try:
|
| 151 |
response = chain.invoke({
|
| 152 |
"disease": disease,
|
| 153 |
"confidence": confidence,
|
| 154 |
"context": context
|
| 155 |
})
|
| 156 |
-
|
| 157 |
# Parse structured response
|
| 158 |
content = response.content
|
| 159 |
explanation = self._parse_explanation(content)
|
| 160 |
-
|
| 161 |
except Exception as e:
|
| 162 |
print(f"Warning: LLM explanation generation failed: {e}")
|
| 163 |
explanation = {
|
|
@@ -166,9 +168,9 @@ class DiseaseExplainerAgent:
|
|
| 166 |
"clinical_presentation": "Clinical presentation varies by individual.",
|
| 167 |
"summary": f"{disease} detected with {confidence:.1%} confidence. Consult healthcare provider."
|
| 168 |
}
|
| 169 |
-
|
| 170 |
return explanation
|
| 171 |
-
|
| 172 |
def _parse_explanation(self, content: str) -> dict:
|
| 173 |
"""Parse LLM response into structured sections"""
|
| 174 |
sections = {
|
|
@@ -177,14 +179,14 @@ class DiseaseExplainerAgent:
|
|
| 177 |
"clinical_presentation": "",
|
| 178 |
"summary": ""
|
| 179 |
}
|
| 180 |
-
|
| 181 |
# Simple parsing logic
|
| 182 |
current_section = None
|
| 183 |
lines = content.split('\n')
|
| 184 |
-
|
| 185 |
for line in lines:
|
| 186 |
line_upper = line.upper().strip()
|
| 187 |
-
|
| 188 |
if 'PATHOPHYSIOLOGY' in line_upper:
|
| 189 |
current_section = 'pathophysiology'
|
| 190 |
elif 'DIAGNOSTIC' in line_upper:
|
|
@@ -195,31 +197,31 @@ class DiseaseExplainerAgent:
|
|
| 195 |
current_section = 'summary'
|
| 196 |
elif current_section and line.strip():
|
| 197 |
sections[current_section] += line + "\n"
|
| 198 |
-
|
| 199 |
# If parsing failed, use full content as summary
|
| 200 |
if not any(sections.values()):
|
| 201 |
sections['summary'] = content[:500]
|
| 202 |
-
|
| 203 |
return sections
|
| 204 |
-
|
| 205 |
def _extract_citations(self, docs: list) -> list:
|
| 206 |
"""Extract citations from retrieved documents"""
|
| 207 |
citations = []
|
| 208 |
-
|
| 209 |
for doc in docs:
|
| 210 |
source = doc.metadata.get('source', 'Unknown')
|
| 211 |
page = doc.metadata.get('page', 'N/A')
|
| 212 |
-
|
| 213 |
# Clean up source path
|
| 214 |
if '\\' in source or '/' in source:
|
| 215 |
source = Path(source).name
|
| 216 |
-
|
| 217 |
citation = f"{source}"
|
| 218 |
if page != 'N/A':
|
| 219 |
citation += f" (Page {page})"
|
| 220 |
-
|
| 221 |
citations.append(citation)
|
| 222 |
-
|
| 223 |
return citations
|
| 224 |
|
| 225 |
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from pathlib import Path
|
| 7 |
+
|
|
|
|
| 8 |
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
|
| 10 |
+
from src.llm_config import llm_config
|
| 11 |
+
from src.state import AgentOutput, GuildState
|
| 12 |
+
|
| 13 |
|
| 14 |
class DiseaseExplainerAgent:
|
| 15 |
"""Agent that retrieves and explains disease mechanisms using RAG"""
|
| 16 |
+
|
| 17 |
def __init__(self, retriever):
|
| 18 |
"""
|
| 19 |
Initialize with a retriever for medical PDFs.
|
|
|
|
| 23 |
"""
|
| 24 |
self.retriever = retriever
|
| 25 |
self.llm = llm_config.explainer
|
| 26 |
+
|
| 27 |
def explain(self, state: GuildState) -> GuildState:
|
| 28 |
"""
|
| 29 |
Retrieve and explain disease pathophysiology.
|
|
|
|
| 37 |
print("\n" + "="*70)
|
| 38 |
print("EXECUTING: Disease Explainer Agent (RAG)")
|
| 39 |
print("="*70)
|
| 40 |
+
|
| 41 |
model_prediction = state['model_prediction']
|
| 42 |
disease = model_prediction['disease']
|
| 43 |
confidence = model_prediction['confidence']
|
| 44 |
+
|
| 45 |
# Configure retrieval based on SOP — create a copy to avoid mutating shared retriever
|
| 46 |
retrieval_k = state['sop'].disease_explainer_k
|
| 47 |
original_search_kwargs = dict(self.retriever.search_kwargs)
|
| 48 |
self.retriever.search_kwargs = {**original_search_kwargs, 'k': retrieval_k}
|
| 49 |
+
|
| 50 |
# Retrieve relevant documents
|
| 51 |
print(f"\nRetrieving information about: {disease}")
|
| 52 |
print(f"Retrieval k={state['sop'].disease_explainer_k}")
|
| 53 |
+
|
| 54 |
query = f"""What is {disease}? Explain the pathophysiology, diagnostic criteria,
|
| 55 |
and clinical presentation. Focus on mechanisms relevant to blood biomarkers."""
|
| 56 |
+
|
| 57 |
try:
|
| 58 |
docs = self.retriever.invoke(query)
|
| 59 |
finally:
|
|
|
|
| 89 |
print(" - Pathophysiology: insufficient evidence")
|
| 90 |
print(" - Citations: 0 sources")
|
| 91 |
return {'agent_outputs': [output]}
|
| 92 |
+
|
| 93 |
# Generate explanation
|
| 94 |
explanation = self._generate_explanation(disease, docs, confidence)
|
| 95 |
+
|
| 96 |
# Extract citations
|
| 97 |
citations = self._extract_citations(docs)
|
| 98 |
+
|
| 99 |
# Create agent output
|
| 100 |
output = AgentOutput(
|
| 101 |
agent_name="Disease Explainer",
|
|
|
|
| 111 |
"citations_missing": False
|
| 112 |
}
|
| 113 |
)
|
| 114 |
+
|
| 115 |
# Update state
|
| 116 |
print("\nDisease explanation generated")
|
| 117 |
print(f" - Pathophysiology: {len(explanation['pathophysiology'])} chars")
|
| 118 |
print(f" - Citations: {len(citations)} sources")
|
| 119 |
+
|
| 120 |
return {'agent_outputs': [output]}
|
| 121 |
+
|
| 122 |
def _generate_explanation(self, disease: str, docs: list, confidence: float) -> dict:
|
| 123 |
"""Generate structured disease explanation using LLM and retrieved docs"""
|
| 124 |
+
|
| 125 |
# Format retrieved context
|
| 126 |
context = "\n\n---\n\n".join([
|
| 127 |
f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
|
| 128 |
for doc in docs
|
| 129 |
])
|
| 130 |
+
|
| 131 |
prompt = ChatPromptTemplate.from_messages([
|
| 132 |
("system", """You are a medical expert explaining diseases for patient self-assessment.
|
| 133 |
Based on the provided medical literature, explain the disease in clear, accessible language.
|
|
|
|
| 146 |
|
| 147 |
Please provide a structured explanation.""")
|
| 148 |
])
|
| 149 |
+
|
| 150 |
chain = prompt | self.llm
|
| 151 |
+
|
| 152 |
try:
|
| 153 |
response = chain.invoke({
|
| 154 |
"disease": disease,
|
| 155 |
"confidence": confidence,
|
| 156 |
"context": context
|
| 157 |
})
|
| 158 |
+
|
| 159 |
# Parse structured response
|
| 160 |
content = response.content
|
| 161 |
explanation = self._parse_explanation(content)
|
| 162 |
+
|
| 163 |
except Exception as e:
|
| 164 |
print(f"Warning: LLM explanation generation failed: {e}")
|
| 165 |
explanation = {
|
|
|
|
| 168 |
"clinical_presentation": "Clinical presentation varies by individual.",
|
| 169 |
"summary": f"{disease} detected with {confidence:.1%} confidence. Consult healthcare provider."
|
| 170 |
}
|
| 171 |
+
|
| 172 |
return explanation
|
| 173 |
+
|
| 174 |
def _parse_explanation(self, content: str) -> dict:
|
| 175 |
"""Parse LLM response into structured sections"""
|
| 176 |
sections = {
|
|
|
|
| 179 |
"clinical_presentation": "",
|
| 180 |
"summary": ""
|
| 181 |
}
|
| 182 |
+
|
| 183 |
# Simple parsing logic
|
| 184 |
current_section = None
|
| 185 |
lines = content.split('\n')
|
| 186 |
+
|
| 187 |
for line in lines:
|
| 188 |
line_upper = line.upper().strip()
|
| 189 |
+
|
| 190 |
if 'PATHOPHYSIOLOGY' in line_upper:
|
| 191 |
current_section = 'pathophysiology'
|
| 192 |
elif 'DIAGNOSTIC' in line_upper:
|
|
|
|
| 197 |
current_section = 'summary'
|
| 198 |
elif current_section and line.strip():
|
| 199 |
sections[current_section] += line + "\n"
|
| 200 |
+
|
| 201 |
# If parsing failed, use full content as summary
|
| 202 |
if not any(sections.values()):
|
| 203 |
sections['summary'] = content[:500]
|
| 204 |
+
|
| 205 |
return sections
|
| 206 |
+
|
| 207 |
def _extract_citations(self, docs: list) -> list:
|
| 208 |
"""Extract citations from retrieved documents"""
|
| 209 |
citations = []
|
| 210 |
+
|
| 211 |
for doc in docs:
|
| 212 |
source = doc.metadata.get('source', 'Unknown')
|
| 213 |
page = doc.metadata.get('page', 'N/A')
|
| 214 |
+
|
| 215 |
# Clean up source path
|
| 216 |
if '\\' in source or '/' in source:
|
| 217 |
source = Path(source).name
|
| 218 |
+
|
| 219 |
citation = f"{source}"
|
| 220 |
if page != 'N/A':
|
| 221 |
citation += f" (Page {page})"
|
| 222 |
+
|
| 223 |
citations.append(citation)
|
| 224 |
+
|
| 225 |
return citations
|
| 226 |
|
| 227 |
|
src/agents/response_synthesizer.py
CHANGED
|
@@ -3,19 +3,20 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Response Synthesizer Agent - Compiles all findings into final structured JSON
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import
|
| 7 |
-
|
| 8 |
-
from src.state import GuildState
|
| 9 |
-
from src.llm_config import llm_config
|
| 10 |
from langchain_core.prompts import ChatPromptTemplate
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class ResponseSynthesizerAgent:
|
| 14 |
"""Agent that synthesizes all specialist findings into the final response"""
|
| 15 |
-
|
| 16 |
def __init__(self):
|
| 17 |
self.llm = llm_config.get_synthesizer()
|
| 18 |
-
|
| 19 |
def synthesize(self, state: GuildState) -> GuildState:
|
| 20 |
"""
|
| 21 |
Synthesize all agent outputs into final response.
|
|
@@ -29,17 +30,17 @@ class ResponseSynthesizerAgent:
|
|
| 29 |
print("\n" + "="*70)
|
| 30 |
print("EXECUTING: Response Synthesizer Agent")
|
| 31 |
print("="*70)
|
| 32 |
-
|
| 33 |
model_prediction = state['model_prediction']
|
| 34 |
patient_biomarkers = state['patient_biomarkers']
|
| 35 |
patient_context = state.get('patient_context', {})
|
| 36 |
agent_outputs = state.get('agent_outputs', [])
|
| 37 |
-
|
| 38 |
# Collect findings from all agents
|
| 39 |
findings = self._collect_findings(agent_outputs)
|
| 40 |
-
|
| 41 |
print(f"\nSynthesizing findings from {len(agent_outputs)} specialist agents...")
|
| 42 |
-
|
| 43 |
# Build structured response
|
| 44 |
recs = self._build_recommendations(findings)
|
| 45 |
response = {
|
|
@@ -64,38 +65,38 @@ class ResponseSynthesizerAgent:
|
|
| 64 |
"alternative_diagnoses": self._build_alternative_diagnoses(findings)
|
| 65 |
}
|
| 66 |
}
|
| 67 |
-
|
| 68 |
# Generate patient-friendly summary
|
| 69 |
response["patient_summary"]["narrative"] = self._generate_narrative_summary(
|
| 70 |
model_prediction,
|
| 71 |
findings,
|
| 72 |
response
|
| 73 |
)
|
| 74 |
-
|
| 75 |
print("\nResponse synthesis complete")
|
| 76 |
-
print(
|
| 77 |
print(f" - Prediction explanation: {len(response['prediction_explanation']['key_drivers'])} key drivers")
|
| 78 |
print(f" - Recommendations: {len(response['clinical_recommendations']['immediate_actions'])} immediate actions")
|
| 79 |
print(f" - Safety alerts: {len(response['safety_alerts'])} alerts")
|
| 80 |
-
|
| 81 |
return {'final_response': response}
|
| 82 |
-
|
| 83 |
-
def _collect_findings(self, agent_outputs:
|
| 84 |
"""Organize all agent findings by agent name"""
|
| 85 |
findings = {}
|
| 86 |
for output in agent_outputs:
|
| 87 |
findings[output.agent_name] = output.findings
|
| 88 |
return findings
|
| 89 |
-
|
| 90 |
-
def _build_patient_summary(self, biomarkers:
|
| 91 |
"""Build patient summary section"""
|
| 92 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 93 |
flags = biomarker_analysis.get('biomarker_flags', [])
|
| 94 |
-
|
| 95 |
# Count biomarker statuses
|
| 96 |
critical = len([f for f in flags if 'CRITICAL' in f.get('status', '')])
|
| 97 |
abnormal = len([f for f in flags if f.get('status') != 'NORMAL'])
|
| 98 |
-
|
| 99 |
return {
|
| 100 |
"total_biomarkers_tested": len(biomarkers),
|
| 101 |
"biomarkers_in_normal_range": len(flags) - abnormal,
|
|
@@ -104,15 +105,15 @@ class ResponseSynthesizerAgent:
|
|
| 104 |
"overall_risk_profile": biomarker_analysis.get('summary', 'Assessment complete'),
|
| 105 |
"narrative": "" # Will be filled later
|
| 106 |
}
|
| 107 |
-
|
| 108 |
-
def _build_prediction_explanation(self, model_prediction:
|
| 109 |
"""Build prediction explanation section"""
|
| 110 |
disease_explanation = findings.get("Disease Explainer", {})
|
| 111 |
linker_findings = findings.get("Biomarker-Disease Linker", {})
|
| 112 |
-
|
| 113 |
disease = model_prediction['disease']
|
| 114 |
confidence = model_prediction['confidence']
|
| 115 |
-
|
| 116 |
# Get key drivers
|
| 117 |
key_drivers_raw = linker_findings.get('key_drivers', [])
|
| 118 |
key_drivers = [
|
|
@@ -125,7 +126,7 @@ class ResponseSynthesizerAgent:
|
|
| 125 |
}
|
| 126 |
for kd in key_drivers_raw
|
| 127 |
]
|
| 128 |
-
|
| 129 |
return {
|
| 130 |
"primary_disease": disease,
|
| 131 |
"confidence": confidence,
|
|
@@ -135,37 +136,37 @@ class ResponseSynthesizerAgent:
|
|
| 135 |
"pdf_references": disease_explanation.get('citations', [])
|
| 136 |
}
|
| 137 |
|
| 138 |
-
def _build_biomarker_flags(self, findings:
|
| 139 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 140 |
return biomarker_analysis.get('biomarker_flags', [])
|
| 141 |
|
| 142 |
-
def _build_key_drivers(self, findings:
|
| 143 |
linker_findings = findings.get("Biomarker-Disease Linker", {})
|
| 144 |
return linker_findings.get('key_drivers', [])
|
| 145 |
|
| 146 |
-
def _build_disease_explanation(self, findings:
|
| 147 |
disease_explanation = findings.get("Disease Explainer", {})
|
| 148 |
return {
|
| 149 |
"pathophysiology": disease_explanation.get('pathophysiology', ''),
|
| 150 |
"citations": disease_explanation.get('citations', []),
|
| 151 |
"retrieved_chunks": disease_explanation.get('retrieved_chunks')
|
| 152 |
}
|
| 153 |
-
|
| 154 |
-
def _build_recommendations(self, findings:
|
| 155 |
"""Build clinical recommendations section"""
|
| 156 |
guidelines = findings.get("Clinical Guidelines", {})
|
| 157 |
-
|
| 158 |
return {
|
| 159 |
"immediate_actions": guidelines.get('immediate_actions', []),
|
| 160 |
"lifestyle_changes": guidelines.get('lifestyle_changes', []),
|
| 161 |
"monitoring": guidelines.get('monitoring', []),
|
| 162 |
"guideline_citations": guidelines.get('guideline_citations', [])
|
| 163 |
}
|
| 164 |
-
|
| 165 |
-
def _build_confidence_assessment(self, findings:
|
| 166 |
"""Build confidence assessment section"""
|
| 167 |
assessment = findings.get("Confidence Assessor", {})
|
| 168 |
-
|
| 169 |
return {
|
| 170 |
"prediction_reliability": assessment.get('prediction_reliability', 'UNKNOWN'),
|
| 171 |
"evidence_strength": assessment.get('evidence_strength', 'UNKNOWN'),
|
|
@@ -175,19 +176,19 @@ class ResponseSynthesizerAgent:
|
|
| 175 |
"alternative_diagnoses": assessment.get('alternative_diagnoses', [])
|
| 176 |
}
|
| 177 |
|
| 178 |
-
def _build_alternative_diagnoses(self, findings:
|
| 179 |
assessment = findings.get("Confidence Assessor", {})
|
| 180 |
return assessment.get('alternative_diagnoses', [])
|
| 181 |
-
|
| 182 |
-
def _build_safety_alerts(self, findings:
|
| 183 |
"""Build safety alerts section"""
|
| 184 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 185 |
return biomarker_analysis.get('safety_alerts', [])
|
| 186 |
-
|
| 187 |
-
def _build_metadata(self, state: GuildState) ->
|
| 188 |
"""Build metadata section"""
|
| 189 |
from datetime import datetime
|
| 190 |
-
|
| 191 |
return {
|
| 192 |
"timestamp": datetime.now().isoformat(),
|
| 193 |
"system_version": "MediGuard AI RAG-Helper v1.0",
|
|
@@ -195,24 +196,24 @@ class ResponseSynthesizerAgent:
|
|
| 195 |
"agents_executed": [output.agent_name for output in state.get('agent_outputs', [])],
|
| 196 |
"disclaimer": "This is an AI-assisted analysis tool for patient self-assessment. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always consult qualified healthcare providers for medical decisions."
|
| 197 |
}
|
| 198 |
-
|
| 199 |
def _generate_narrative_summary(
|
| 200 |
self,
|
| 201 |
model_prediction,
|
| 202 |
-
findings:
|
| 203 |
-
response:
|
| 204 |
) -> str:
|
| 205 |
"""Generate a patient-friendly narrative summary using LLM"""
|
| 206 |
-
|
| 207 |
disease = model_prediction['disease']
|
| 208 |
confidence = model_prediction['confidence']
|
| 209 |
reliability = response['confidence_assessment']['prediction_reliability']
|
| 210 |
-
|
| 211 |
# Get key points
|
| 212 |
critical_count = response['patient_summary']['critical_values']
|
| 213 |
abnormal_count = response['patient_summary']['biomarkers_out_of_range']
|
| 214 |
key_drivers = response['prediction_explanation']['key_drivers']
|
| 215 |
-
|
| 216 |
prompt = ChatPromptTemplate.from_messages([
|
| 217 |
("system", """You are a medical AI assistant explaining test results to a patient.
|
| 218 |
Write a clear, compassionate 3-4 sentence summary that:
|
|
@@ -231,12 +232,12 @@ class ResponseSynthesizerAgent:
|
|
| 231 |
|
| 232 |
Write a compassionate patient summary.""")
|
| 233 |
])
|
| 234 |
-
|
| 235 |
chain = prompt | self.llm
|
| 236 |
-
|
| 237 |
try:
|
| 238 |
driver_names = [kd['biomarker'] for kd in key_drivers[:3]]
|
| 239 |
-
|
| 240 |
response_obj = chain.invoke({
|
| 241 |
"disease": disease,
|
| 242 |
"confidence": confidence,
|
|
@@ -245,9 +246,9 @@ class ResponseSynthesizerAgent:
|
|
| 245 |
"abnormal": abnormal_count,
|
| 246 |
"drivers": ", ".join(driver_names) if driver_names else "Multiple biomarkers"
|
| 247 |
})
|
| 248 |
-
|
| 249 |
return response_obj.content.strip()
|
| 250 |
-
|
| 251 |
except Exception as e:
|
| 252 |
print(f"Warning: Narrative generation failed: {e}")
|
| 253 |
return f"Your test results suggest {disease} with {confidence:.1%} confidence. {abnormal_count} biomarker(s) are out of normal range. Please consult with a healthcare provider for professional evaluation and guidance."
|
|
|
|
| 3 |
Response Synthesizer Agent - Compiles all findings into final structured JSON
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
|
|
|
|
|
|
| 8 |
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
|
| 10 |
+
from src.llm_config import llm_config
|
| 11 |
+
from src.state import GuildState
|
| 12 |
+
|
| 13 |
|
| 14 |
class ResponseSynthesizerAgent:
|
| 15 |
"""Agent that synthesizes all specialist findings into the final response"""
|
| 16 |
+
|
| 17 |
def __init__(self):
|
| 18 |
self.llm = llm_config.get_synthesizer()
|
| 19 |
+
|
| 20 |
def synthesize(self, state: GuildState) -> GuildState:
|
| 21 |
"""
|
| 22 |
Synthesize all agent outputs into final response.
|
|
|
|
| 30 |
print("\n" + "="*70)
|
| 31 |
print("EXECUTING: Response Synthesizer Agent")
|
| 32 |
print("="*70)
|
| 33 |
+
|
| 34 |
model_prediction = state['model_prediction']
|
| 35 |
patient_biomarkers = state['patient_biomarkers']
|
| 36 |
patient_context = state.get('patient_context', {})
|
| 37 |
agent_outputs = state.get('agent_outputs', [])
|
| 38 |
+
|
| 39 |
# Collect findings from all agents
|
| 40 |
findings = self._collect_findings(agent_outputs)
|
| 41 |
+
|
| 42 |
print(f"\nSynthesizing findings from {len(agent_outputs)} specialist agents...")
|
| 43 |
+
|
| 44 |
# Build structured response
|
| 45 |
recs = self._build_recommendations(findings)
|
| 46 |
response = {
|
|
|
|
| 65 |
"alternative_diagnoses": self._build_alternative_diagnoses(findings)
|
| 66 |
}
|
| 67 |
}
|
| 68 |
+
|
| 69 |
# Generate patient-friendly summary
|
| 70 |
response["patient_summary"]["narrative"] = self._generate_narrative_summary(
|
| 71 |
model_prediction,
|
| 72 |
findings,
|
| 73 |
response
|
| 74 |
)
|
| 75 |
+
|
| 76 |
print("\nResponse synthesis complete")
|
| 77 |
+
print(" - Patient summary: Generated")
|
| 78 |
print(f" - Prediction explanation: {len(response['prediction_explanation']['key_drivers'])} key drivers")
|
| 79 |
print(f" - Recommendations: {len(response['clinical_recommendations']['immediate_actions'])} immediate actions")
|
| 80 |
print(f" - Safety alerts: {len(response['safety_alerts'])} alerts")
|
| 81 |
+
|
| 82 |
return {'final_response': response}
|
| 83 |
+
|
| 84 |
+
def _collect_findings(self, agent_outputs: list) -> dict[str, Any]:
|
| 85 |
"""Organize all agent findings by agent name"""
|
| 86 |
findings = {}
|
| 87 |
for output in agent_outputs:
|
| 88 |
findings[output.agent_name] = output.findings
|
| 89 |
return findings
|
| 90 |
+
|
| 91 |
+
def _build_patient_summary(self, biomarkers: dict, findings: dict) -> dict:
|
| 92 |
"""Build patient summary section"""
|
| 93 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 94 |
flags = biomarker_analysis.get('biomarker_flags', [])
|
| 95 |
+
|
| 96 |
# Count biomarker statuses
|
| 97 |
critical = len([f for f in flags if 'CRITICAL' in f.get('status', '')])
|
| 98 |
abnormal = len([f for f in flags if f.get('status') != 'NORMAL'])
|
| 99 |
+
|
| 100 |
return {
|
| 101 |
"total_biomarkers_tested": len(biomarkers),
|
| 102 |
"biomarkers_in_normal_range": len(flags) - abnormal,
|
|
|
|
| 105 |
"overall_risk_profile": biomarker_analysis.get('summary', 'Assessment complete'),
|
| 106 |
"narrative": "" # Will be filled later
|
| 107 |
}
|
| 108 |
+
|
| 109 |
+
def _build_prediction_explanation(self, model_prediction: dict, findings: dict) -> dict:
|
| 110 |
"""Build prediction explanation section"""
|
| 111 |
disease_explanation = findings.get("Disease Explainer", {})
|
| 112 |
linker_findings = findings.get("Biomarker-Disease Linker", {})
|
| 113 |
+
|
| 114 |
disease = model_prediction['disease']
|
| 115 |
confidence = model_prediction['confidence']
|
| 116 |
+
|
| 117 |
# Get key drivers
|
| 118 |
key_drivers_raw = linker_findings.get('key_drivers', [])
|
| 119 |
key_drivers = [
|
|
|
|
| 126 |
}
|
| 127 |
for kd in key_drivers_raw
|
| 128 |
]
|
| 129 |
+
|
| 130 |
return {
|
| 131 |
"primary_disease": disease,
|
| 132 |
"confidence": confidence,
|
|
|
|
| 136 |
"pdf_references": disease_explanation.get('citations', [])
|
| 137 |
}
|
| 138 |
|
| 139 |
+
def _build_biomarker_flags(self, findings: dict) -> list[dict]:
|
| 140 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 141 |
return biomarker_analysis.get('biomarker_flags', [])
|
| 142 |
|
| 143 |
+
def _build_key_drivers(self, findings: dict) -> list[dict]:
|
| 144 |
linker_findings = findings.get("Biomarker-Disease Linker", {})
|
| 145 |
return linker_findings.get('key_drivers', [])
|
| 146 |
|
| 147 |
+
def _build_disease_explanation(self, findings: dict) -> dict:
|
| 148 |
disease_explanation = findings.get("Disease Explainer", {})
|
| 149 |
return {
|
| 150 |
"pathophysiology": disease_explanation.get('pathophysiology', ''),
|
| 151 |
"citations": disease_explanation.get('citations', []),
|
| 152 |
"retrieved_chunks": disease_explanation.get('retrieved_chunks')
|
| 153 |
}
|
| 154 |
+
|
| 155 |
+
def _build_recommendations(self, findings: dict) -> dict:
|
| 156 |
"""Build clinical recommendations section"""
|
| 157 |
guidelines = findings.get("Clinical Guidelines", {})
|
| 158 |
+
|
| 159 |
return {
|
| 160 |
"immediate_actions": guidelines.get('immediate_actions', []),
|
| 161 |
"lifestyle_changes": guidelines.get('lifestyle_changes', []),
|
| 162 |
"monitoring": guidelines.get('monitoring', []),
|
| 163 |
"guideline_citations": guidelines.get('guideline_citations', [])
|
| 164 |
}
|
| 165 |
+
|
| 166 |
+
def _build_confidence_assessment(self, findings: dict) -> dict:
|
| 167 |
"""Build confidence assessment section"""
|
| 168 |
assessment = findings.get("Confidence Assessor", {})
|
| 169 |
+
|
| 170 |
return {
|
| 171 |
"prediction_reliability": assessment.get('prediction_reliability', 'UNKNOWN'),
|
| 172 |
"evidence_strength": assessment.get('evidence_strength', 'UNKNOWN'),
|
|
|
|
| 176 |
"alternative_diagnoses": assessment.get('alternative_diagnoses', [])
|
| 177 |
}
|
| 178 |
|
| 179 |
+
def _build_alternative_diagnoses(self, findings: dict) -> list[dict]:
|
| 180 |
assessment = findings.get("Confidence Assessor", {})
|
| 181 |
return assessment.get('alternative_diagnoses', [])
|
| 182 |
+
|
| 183 |
+
def _build_safety_alerts(self, findings: dict) -> list[dict]:
|
| 184 |
"""Build safety alerts section"""
|
| 185 |
biomarker_analysis = findings.get("Biomarker Analyzer", {})
|
| 186 |
return biomarker_analysis.get('safety_alerts', [])
|
| 187 |
+
|
| 188 |
+
def _build_metadata(self, state: GuildState) -> dict:
|
| 189 |
"""Build metadata section"""
|
| 190 |
from datetime import datetime
|
| 191 |
+
|
| 192 |
return {
|
| 193 |
"timestamp": datetime.now().isoformat(),
|
| 194 |
"system_version": "MediGuard AI RAG-Helper v1.0",
|
|
|
|
| 196 |
"agents_executed": [output.agent_name for output in state.get('agent_outputs', [])],
|
| 197 |
"disclaimer": "This is an AI-assisted analysis tool for patient self-assessment. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always consult qualified healthcare providers for medical decisions."
|
| 198 |
}
|
| 199 |
+
|
| 200 |
def _generate_narrative_summary(
|
| 201 |
self,
|
| 202 |
model_prediction,
|
| 203 |
+
findings: dict,
|
| 204 |
+
response: dict
|
| 205 |
) -> str:
|
| 206 |
"""Generate a patient-friendly narrative summary using LLM"""
|
| 207 |
+
|
| 208 |
disease = model_prediction['disease']
|
| 209 |
confidence = model_prediction['confidence']
|
| 210 |
reliability = response['confidence_assessment']['prediction_reliability']
|
| 211 |
+
|
| 212 |
# Get key points
|
| 213 |
critical_count = response['patient_summary']['critical_values']
|
| 214 |
abnormal_count = response['patient_summary']['biomarkers_out_of_range']
|
| 215 |
key_drivers = response['prediction_explanation']['key_drivers']
|
| 216 |
+
|
| 217 |
prompt = ChatPromptTemplate.from_messages([
|
| 218 |
("system", """You are a medical AI assistant explaining test results to a patient.
|
| 219 |
Write a clear, compassionate 3-4 sentence summary that:
|
|
|
|
| 232 |
|
| 233 |
Write a compassionate patient summary.""")
|
| 234 |
])
|
| 235 |
+
|
| 236 |
chain = prompt | self.llm
|
| 237 |
+
|
| 238 |
try:
|
| 239 |
driver_names = [kd['biomarker'] for kd in key_drivers[:3]]
|
| 240 |
+
|
| 241 |
response_obj = chain.invoke({
|
| 242 |
"disease": disease,
|
| 243 |
"confidence": confidence,
|
|
|
|
| 246 |
"abnormal": abnormal_count,
|
| 247 |
"drivers": ", ".join(driver_names) if driver_names else "Multiple biomarkers"
|
| 248 |
})
|
| 249 |
+
|
| 250 |
return response_obj.content.strip()
|
| 251 |
+
|
| 252 |
except Exception as e:
|
| 253 |
print(f"Warning: Narrative generation failed: {e}")
|
| 254 |
return f"Your test results suggest {disease} with {confidence:.1%} confidence. {abnormal_count} biomarker(s) are out of normal range. Please consult with a healthcare provider for professional evaluation and guidance."
|
src/biomarker_normalization.py
CHANGED
|
@@ -3,10 +3,9 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Shared biomarker normalization utilities
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
from typing import Dict
|
| 7 |
|
| 8 |
# Normalization map for biomarker aliases to canonical names.
|
| 9 |
-
NORMALIZATION_MAP:
|
| 10 |
# Glucose variations
|
| 11 |
"glucose": "Glucose",
|
| 12 |
"bloodsugar": "Glucose",
|
|
|
|
| 3 |
Shared biomarker normalization utilities
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 6 |
|
| 7 |
# Normalization map for biomarker aliases to canonical names.
|
| 8 |
+
NORMALIZATION_MAP: dict[str, str] = {
|
| 9 |
# Glucose variations
|
| 10 |
"glucose": "Glucose",
|
| 11 |
"bloodsugar": "Glucose",
|
src/biomarker_validator.py
CHANGED
|
@@ -5,24 +5,24 @@ Biomarker analysis and validation utilities
|
|
| 5 |
|
| 6 |
import json
|
| 7 |
from pathlib import Path
|
| 8 |
-
|
| 9 |
from src.state import BiomarkerFlag, SafetyAlert
|
| 10 |
|
| 11 |
|
| 12 |
class BiomarkerValidator:
|
| 13 |
"""Validates biomarker values against reference ranges"""
|
| 14 |
-
|
| 15 |
def __init__(self, reference_file: str = "config/biomarker_references.json"):
|
| 16 |
"""Load biomarker reference ranges from JSON file"""
|
| 17 |
ref_path = Path(__file__).parent.parent / reference_file
|
| 18 |
-
with open(ref_path
|
| 19 |
self.references = json.load(f)['biomarkers']
|
| 20 |
-
|
| 21 |
def validate_biomarker(
|
| 22 |
-
self,
|
| 23 |
-
name: str,
|
| 24 |
-
value: float,
|
| 25 |
-
gender:
|
| 26 |
threshold_pct: float = 0.0
|
| 27 |
) -> BiomarkerFlag:
|
| 28 |
"""
|
|
@@ -46,10 +46,10 @@ class BiomarkerValidator:
|
|
| 46 |
reference_range="No reference data available",
|
| 47 |
warning=f"No reference range found for {name}"
|
| 48 |
)
|
| 49 |
-
|
| 50 |
ref = self.references[name]
|
| 51 |
unit = ref['unit']
|
| 52 |
-
|
| 53 |
# Handle gender-specific ranges
|
| 54 |
if ref.get('gender_specific', False) and gender:
|
| 55 |
if gender.lower() in ['male', 'm']:
|
|
@@ -60,16 +60,16 @@ class BiomarkerValidator:
|
|
| 60 |
normal = ref['normal_range']
|
| 61 |
else:
|
| 62 |
normal = ref['normal_range']
|
| 63 |
-
|
| 64 |
min_val = normal.get('min', 0)
|
| 65 |
max_val = normal.get('max', float('inf'))
|
| 66 |
critical_low = ref.get('critical_low')
|
| 67 |
critical_high = ref.get('critical_high')
|
| 68 |
-
|
| 69 |
# Determine status
|
| 70 |
status = "NORMAL"
|
| 71 |
warning = None
|
| 72 |
-
|
| 73 |
# Check critical values first (threshold_pct does not suppress critical alerts)
|
| 74 |
if critical_low and value < critical_low:
|
| 75 |
status = "CRITICAL_LOW"
|
|
@@ -88,9 +88,9 @@ class BiomarkerValidator:
|
|
| 88 |
if deviation > threshold_pct:
|
| 89 |
status = "HIGH"
|
| 90 |
warning = f"{name} is {value} {unit}, above normal range ({min_val}-{max_val} {unit}). {ref['clinical_significance'].get('high', '')}"
|
| 91 |
-
|
| 92 |
reference_range = f"{min_val}-{max_val} {unit}"
|
| 93 |
-
|
| 94 |
return BiomarkerFlag(
|
| 95 |
name=name,
|
| 96 |
value=value,
|
|
@@ -99,13 +99,13 @@ class BiomarkerValidator:
|
|
| 99 |
reference_range=reference_range,
|
| 100 |
warning=warning
|
| 101 |
)
|
| 102 |
-
|
| 103 |
def validate_all(
|
| 104 |
self,
|
| 105 |
-
biomarkers:
|
| 106 |
-
gender:
|
| 107 |
threshold_pct: float = 0.0
|
| 108 |
-
) ->
|
| 109 |
"""
|
| 110 |
Validate all biomarker values.
|
| 111 |
|
|
@@ -119,11 +119,11 @@ class BiomarkerValidator:
|
|
| 119 |
"""
|
| 120 |
flags = []
|
| 121 |
alerts = []
|
| 122 |
-
|
| 123 |
for name, value in biomarkers.items():
|
| 124 |
flag = self.validate_biomarker(name, value, gender, threshold_pct)
|
| 125 |
flags.append(flag)
|
| 126 |
-
|
| 127 |
# Generate safety alerts for critical values
|
| 128 |
if flag.status in ["CRITICAL_LOW", "CRITICAL_HIGH"]:
|
| 129 |
alerts.append(SafetyAlert(
|
|
@@ -140,18 +140,18 @@ class BiomarkerValidator:
|
|
| 140 |
message=flag.warning or f"{name} out of normal range",
|
| 141 |
action="Consult with healthcare provider"
|
| 142 |
))
|
| 143 |
-
|
| 144 |
return flags, alerts
|
| 145 |
-
|
| 146 |
-
def get_biomarker_info(self, name: str) ->
|
| 147 |
"""Get reference information for a biomarker"""
|
| 148 |
return self.references.get(name)
|
| 149 |
|
| 150 |
def expected_biomarker_count(self) -> int:
|
| 151 |
"""Return expected number of biomarkers from reference ranges."""
|
| 152 |
return len(self.references)
|
| 153 |
-
|
| 154 |
-
def get_disease_relevant_biomarkers(self, disease: str) ->
|
| 155 |
"""
|
| 156 |
Get list of biomarkers most relevant to a specific disease.
|
| 157 |
|
|
@@ -159,19 +159,19 @@ class BiomarkerValidator:
|
|
| 159 |
"""
|
| 160 |
disease_map = {
|
| 161 |
"Diabetes": [
|
| 162 |
-
"Glucose", "HbA1c", "Insulin", "BMI",
|
| 163 |
"Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
|
| 164 |
],
|
| 165 |
"Type 2 Diabetes": [
|
| 166 |
-
"Glucose", "HbA1c", "Insulin", "BMI",
|
| 167 |
"Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
|
| 168 |
],
|
| 169 |
"Type 1 Diabetes": [
|
| 170 |
-
"Glucose", "HbA1c", "Insulin", "BMI",
|
| 171 |
"Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
|
| 172 |
],
|
| 173 |
"Anemia": [
|
| 174 |
-
"Hemoglobin", "Red Blood Cells", "Hematocrit",
|
| 175 |
"Mean Corpuscular Volume", "Mean Corpuscular Hemoglobin",
|
| 176 |
"Mean Corpuscular Hemoglobin Concentration"
|
| 177 |
],
|
|
@@ -189,5 +189,5 @@ class BiomarkerValidator:
|
|
| 189 |
"Heart Rate", "BMI"
|
| 190 |
]
|
| 191 |
}
|
| 192 |
-
|
| 193 |
return disease_map.get(disease, [])
|
|
|
|
| 5 |
|
| 6 |
import json
|
| 7 |
from pathlib import Path
|
| 8 |
+
|
| 9 |
from src.state import BiomarkerFlag, SafetyAlert
|
| 10 |
|
| 11 |
|
| 12 |
class BiomarkerValidator:
|
| 13 |
"""Validates biomarker values against reference ranges"""
|
| 14 |
+
|
| 15 |
def __init__(self, reference_file: str = "config/biomarker_references.json"):
|
| 16 |
"""Load biomarker reference ranges from JSON file"""
|
| 17 |
ref_path = Path(__file__).parent.parent / reference_file
|
| 18 |
+
with open(ref_path) as f:
|
| 19 |
self.references = json.load(f)['biomarkers']
|
| 20 |
+
|
| 21 |
def validate_biomarker(
|
| 22 |
+
self,
|
| 23 |
+
name: str,
|
| 24 |
+
value: float,
|
| 25 |
+
gender: str | None = None,
|
| 26 |
threshold_pct: float = 0.0
|
| 27 |
) -> BiomarkerFlag:
|
| 28 |
"""
|
|
|
|
| 46 |
reference_range="No reference data available",
|
| 47 |
warning=f"No reference range found for {name}"
|
| 48 |
)
|
| 49 |
+
|
| 50 |
ref = self.references[name]
|
| 51 |
unit = ref['unit']
|
| 52 |
+
|
| 53 |
# Handle gender-specific ranges
|
| 54 |
if ref.get('gender_specific', False) and gender:
|
| 55 |
if gender.lower() in ['male', 'm']:
|
|
|
|
| 60 |
normal = ref['normal_range']
|
| 61 |
else:
|
| 62 |
normal = ref['normal_range']
|
| 63 |
+
|
| 64 |
min_val = normal.get('min', 0)
|
| 65 |
max_val = normal.get('max', float('inf'))
|
| 66 |
critical_low = ref.get('critical_low')
|
| 67 |
critical_high = ref.get('critical_high')
|
| 68 |
+
|
| 69 |
# Determine status
|
| 70 |
status = "NORMAL"
|
| 71 |
warning = None
|
| 72 |
+
|
| 73 |
# Check critical values first (threshold_pct does not suppress critical alerts)
|
| 74 |
if critical_low and value < critical_low:
|
| 75 |
status = "CRITICAL_LOW"
|
|
|
|
| 88 |
if deviation > threshold_pct:
|
| 89 |
status = "HIGH"
|
| 90 |
warning = f"{name} is {value} {unit}, above normal range ({min_val}-{max_val} {unit}). {ref['clinical_significance'].get('high', '')}"
|
| 91 |
+
|
| 92 |
reference_range = f"{min_val}-{max_val} {unit}"
|
| 93 |
+
|
| 94 |
return BiomarkerFlag(
|
| 95 |
name=name,
|
| 96 |
value=value,
|
|
|
|
| 99 |
reference_range=reference_range,
|
| 100 |
warning=warning
|
| 101 |
)
|
| 102 |
+
|
| 103 |
def validate_all(
|
| 104 |
self,
|
| 105 |
+
biomarkers: dict[str, float],
|
| 106 |
+
gender: str | None = None,
|
| 107 |
threshold_pct: float = 0.0
|
| 108 |
+
) -> tuple[list[BiomarkerFlag], list[SafetyAlert]]:
|
| 109 |
"""
|
| 110 |
Validate all biomarker values.
|
| 111 |
|
|
|
|
| 119 |
"""
|
| 120 |
flags = []
|
| 121 |
alerts = []
|
| 122 |
+
|
| 123 |
for name, value in biomarkers.items():
|
| 124 |
flag = self.validate_biomarker(name, value, gender, threshold_pct)
|
| 125 |
flags.append(flag)
|
| 126 |
+
|
| 127 |
# Generate safety alerts for critical values
|
| 128 |
if flag.status in ["CRITICAL_LOW", "CRITICAL_HIGH"]:
|
| 129 |
alerts.append(SafetyAlert(
|
|
|
|
| 140 |
message=flag.warning or f"{name} out of normal range",
|
| 141 |
action="Consult with healthcare provider"
|
| 142 |
))
|
| 143 |
+
|
| 144 |
return flags, alerts
|
| 145 |
+
|
| 146 |
+
def get_biomarker_info(self, name: str) -> dict | None:
|
| 147 |
"""Get reference information for a biomarker"""
|
| 148 |
return self.references.get(name)
|
| 149 |
|
| 150 |
def expected_biomarker_count(self) -> int:
|
| 151 |
"""Return expected number of biomarkers from reference ranges."""
|
| 152 |
return len(self.references)
|
| 153 |
+
|
| 154 |
+
def get_disease_relevant_biomarkers(self, disease: str) -> list[str]:
|
| 155 |
"""
|
| 156 |
Get list of biomarkers most relevant to a specific disease.
|
| 157 |
|
|
|
|
| 159 |
"""
|
| 160 |
disease_map = {
|
| 161 |
"Diabetes": [
|
| 162 |
+
"Glucose", "HbA1c", "Insulin", "BMI",
|
| 163 |
"Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
|
| 164 |
],
|
| 165 |
"Type 2 Diabetes": [
|
| 166 |
+
"Glucose", "HbA1c", "Insulin", "BMI",
|
| 167 |
"Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
|
| 168 |
],
|
| 169 |
"Type 1 Diabetes": [
|
| 170 |
+
"Glucose", "HbA1c", "Insulin", "BMI",
|
| 171 |
"Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
|
| 172 |
],
|
| 173 |
"Anemia": [
|
| 174 |
+
"Hemoglobin", "Red Blood Cells", "Hematocrit",
|
| 175 |
"Mean Corpuscular Volume", "Mean Corpuscular Hemoglobin",
|
| 176 |
"Mean Corpuscular Hemoglobin Concentration"
|
| 177 |
],
|
|
|
|
| 189 |
"Heart Rate", "BMI"
|
| 190 |
]
|
| 191 |
}
|
| 192 |
+
|
| 193 |
return disease_map.get(disease, [])
|
src/config.py
CHANGED
|
@@ -3,8 +3,9 @@ MediGuard AI RAG-Helper
|
|
| 3 |
Core configuration and SOP (Standard Operating Procedures) definitions
|
| 4 |
"""
|
| 5 |
|
|
|
|
|
|
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
-
from typing import Literal, Dict, Any, List, Optional
|
| 8 |
|
| 9 |
|
| 10 |
class ExplanationSOP(BaseModel):
|
|
@@ -13,28 +14,28 @@ class ExplanationSOP(BaseModel):
|
|
| 13 |
This is the 'genome' that controls the entire RAG pipeline behavior.
|
| 14 |
The Outer Loop (Director) will evolve these parameters to improve performance.
|
| 15 |
"""
|
| 16 |
-
|
| 17 |
# === Agent Behavior Parameters ===
|
| 18 |
biomarker_analyzer_threshold: float = Field(
|
| 19 |
default=0.15,
|
| 20 |
description="Percentage deviation from normal range to trigger a warning flag (0.15 = 15%)"
|
| 21 |
)
|
| 22 |
-
|
| 23 |
disease_explainer_k: int = Field(
|
| 24 |
default=5,
|
| 25 |
description="Number of top PDF chunks to retrieve for disease explanation"
|
| 26 |
)
|
| 27 |
-
|
| 28 |
linker_retrieval_k: int = Field(
|
| 29 |
default=3,
|
| 30 |
description="Number of chunks for biomarker-disease linking"
|
| 31 |
)
|
| 32 |
-
|
| 33 |
guideline_retrieval_k: int = Field(
|
| 34 |
default=3,
|
| 35 |
description="Number of chunks for clinical guidelines"
|
| 36 |
)
|
| 37 |
-
|
| 38 |
# === Prompts (Evolvable) ===
|
| 39 |
planner_prompt: str = Field(
|
| 40 |
default="""You are a medical AI coordinator. Create a structured execution plan for analyzing patient biomarkers and explaining a disease prediction.
|
|
@@ -49,7 +50,7 @@ Available specialist agents:
|
|
| 49 |
Output a JSON with key 'plan' containing a list of tasks. Each task must have 'agent', 'task_description', and 'dependencies' keys.""",
|
| 50 |
description="System prompt for the Planner Agent"
|
| 51 |
)
|
| 52 |
-
|
| 53 |
synthesizer_prompt: str = Field(
|
| 54 |
default="""You are a medical communication specialist. Your task is to synthesize findings from specialist agents into a clear, patient-friendly clinical explanation.
|
| 55 |
|
|
@@ -64,39 +65,39 @@ Output a JSON with key 'plan' containing a list of tasks. Each task must have 'a
|
|
| 64 |
Structure your output as specified in the output schema.""",
|
| 65 |
description="System prompt for the Response Synthesizer"
|
| 66 |
)
|
| 67 |
-
|
| 68 |
explainer_detail_level: Literal["concise", "detailed", "comprehensive"] = Field(
|
| 69 |
default="detailed",
|
| 70 |
description="Level of detail in disease mechanism explanations"
|
| 71 |
)
|
| 72 |
-
|
| 73 |
# === Feature Flags ===
|
| 74 |
use_guideline_agent: bool = Field(
|
| 75 |
default=True,
|
| 76 |
description="Whether to retrieve clinical guidelines and recommendations"
|
| 77 |
)
|
| 78 |
-
|
| 79 |
include_alternative_diagnoses: bool = Field(
|
| 80 |
default=True,
|
| 81 |
description="Whether to discuss alternative diagnoses from prediction probabilities"
|
| 82 |
)
|
| 83 |
-
|
| 84 |
require_pdf_citations: bool = Field(
|
| 85 |
default=True,
|
| 86 |
description="Whether to require PDF citations for all claims"
|
| 87 |
)
|
| 88 |
-
|
| 89 |
use_confidence_assessor: bool = Field(
|
| 90 |
default=True,
|
| 91 |
description="Whether to evaluate and report prediction confidence"
|
| 92 |
)
|
| 93 |
-
|
| 94 |
# === Safety Settings ===
|
| 95 |
critical_value_alert_mode: Literal["strict", "moderate", "permissive"] = Field(
|
| 96 |
default="strict",
|
| 97 |
description="Threshold for critical value alerts"
|
| 98 |
)
|
| 99 |
-
|
| 100 |
# === Model Selection ===
|
| 101 |
synthesizer_model: str = Field(
|
| 102 |
default="default",
|
|
|
|
| 3 |
Core configuration and SOP (Standard Operating Procedures) definitions
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
from typing import Literal
|
| 7 |
+
|
| 8 |
from pydantic import BaseModel, Field
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class ExplanationSOP(BaseModel):
|
|
|
|
| 14 |
This is the 'genome' that controls the entire RAG pipeline behavior.
|
| 15 |
The Outer Loop (Director) will evolve these parameters to improve performance.
|
| 16 |
"""
|
| 17 |
+
|
| 18 |
# === Agent Behavior Parameters ===
|
| 19 |
biomarker_analyzer_threshold: float = Field(
|
| 20 |
default=0.15,
|
| 21 |
description="Percentage deviation from normal range to trigger a warning flag (0.15 = 15%)"
|
| 22 |
)
|
| 23 |
+
|
| 24 |
disease_explainer_k: int = Field(
|
| 25 |
default=5,
|
| 26 |
description="Number of top PDF chunks to retrieve for disease explanation"
|
| 27 |
)
|
| 28 |
+
|
| 29 |
linker_retrieval_k: int = Field(
|
| 30 |
default=3,
|
| 31 |
description="Number of chunks for biomarker-disease linking"
|
| 32 |
)
|
| 33 |
+
|
| 34 |
guideline_retrieval_k: int = Field(
|
| 35 |
default=3,
|
| 36 |
description="Number of chunks for clinical guidelines"
|
| 37 |
)
|
| 38 |
+
|
| 39 |
# === Prompts (Evolvable) ===
|
| 40 |
planner_prompt: str = Field(
|
| 41 |
default="""You are a medical AI coordinator. Create a structured execution plan for analyzing patient biomarkers and explaining a disease prediction.
|
|
|
|
| 50 |
Output a JSON with key 'plan' containing a list of tasks. Each task must have 'agent', 'task_description', and 'dependencies' keys.""",
|
| 51 |
description="System prompt for the Planner Agent"
|
| 52 |
)
|
| 53 |
+
|
| 54 |
synthesizer_prompt: str = Field(
|
| 55 |
default="""You are a medical communication specialist. Your task is to synthesize findings from specialist agents into a clear, patient-friendly clinical explanation.
|
| 56 |
|
|
|
|
| 65 |
Structure your output as specified in the output schema.""",
|
| 66 |
description="System prompt for the Response Synthesizer"
|
| 67 |
)
|
| 68 |
+
|
| 69 |
explainer_detail_level: Literal["concise", "detailed", "comprehensive"] = Field(
|
| 70 |
default="detailed",
|
| 71 |
description="Level of detail in disease mechanism explanations"
|
| 72 |
)
|
| 73 |
+
|
| 74 |
# === Feature Flags ===
|
| 75 |
use_guideline_agent: bool = Field(
|
| 76 |
default=True,
|
| 77 |
description="Whether to retrieve clinical guidelines and recommendations"
|
| 78 |
)
|
| 79 |
+
|
| 80 |
include_alternative_diagnoses: bool = Field(
|
| 81 |
default=True,
|
| 82 |
description="Whether to discuss alternative diagnoses from prediction probabilities"
|
| 83 |
)
|
| 84 |
+
|
| 85 |
require_pdf_citations: bool = Field(
|
| 86 |
default=True,
|
| 87 |
description="Whether to require PDF citations for all claims"
|
| 88 |
)
|
| 89 |
+
|
| 90 |
use_confidence_assessor: bool = Field(
|
| 91 |
default=True,
|
| 92 |
description="Whether to evaluate and report prediction confidence"
|
| 93 |
)
|
| 94 |
+
|
| 95 |
# === Safety Settings ===
|
| 96 |
critical_value_alert_mode: Literal["strict", "moderate", "permissive"] = Field(
|
| 97 |
default="strict",
|
| 98 |
description="Threshold for critical value alerts"
|
| 99 |
)
|
| 100 |
+
|
| 101 |
# === Model Selection ===
|
| 102 |
synthesizer_model: str = Field(
|
| 103 |
default="default",
|
src/database.py
CHANGED
|
@@ -6,11 +6,11 @@ Provides SQLAlchemy engine/session factories and the declarative Base.
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
|
|
|
| 9 |
from functools import lru_cache
|
| 10 |
-
from typing import Generator
|
| 11 |
|
| 12 |
from sqlalchemy import create_engine
|
| 13 |
-
from sqlalchemy.orm import
|
| 14 |
|
| 15 |
from src.settings import get_settings
|
| 16 |
|
|
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
+
from collections.abc import Generator
|
| 10 |
from functools import lru_cache
|
|
|
|
| 11 |
|
| 12 |
from sqlalchemy import create_engine
|
| 13 |
+
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
| 14 |
|
| 15 |
from src.settings import get_settings
|
| 16 |
|
src/dependencies.py
CHANGED
|
@@ -6,9 +6,6 @@ Provides factory functions and ``Depends()`` for services used across routers.
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
-
from functools import lru_cache
|
| 10 |
-
|
| 11 |
-
from src.settings import Settings, get_settings
|
| 12 |
from src.services.cache.redis_cache import RedisCache, make_redis_cache
|
| 13 |
from src.services.embeddings.service import EmbeddingService, make_embedding_service
|
| 14 |
from src.services.langfuse.tracer import LangfuseTracer, make_langfuse_tracer
|
|
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
from src.services.cache.redis_cache import RedisCache, make_redis_cache
|
| 10 |
from src.services.embeddings.service import EmbeddingService, make_embedding_service
|
| 11 |
from src.services.langfuse.tracer import LangfuseTracer, make_langfuse_tracer
|
src/evaluation/__init__.py
CHANGED
|
@@ -4,23 +4,23 @@ Exports 5D quality assessment framework components
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from .evaluators import (
|
| 7 |
-
GradedScore,
|
| 8 |
EvaluationResult,
|
| 9 |
-
|
| 10 |
-
evaluate_evidence_grounding,
|
| 11 |
evaluate_actionability,
|
| 12 |
evaluate_clarity,
|
|
|
|
|
|
|
| 13 |
evaluate_safety_completeness,
|
| 14 |
-
run_full_evaluation
|
| 15 |
)
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
-
'GradedScore',
|
| 19 |
'EvaluationResult',
|
| 20 |
-
'
|
| 21 |
-
'evaluate_evidence_grounding',
|
| 22 |
'evaluate_actionability',
|
| 23 |
'evaluate_clarity',
|
|
|
|
|
|
|
| 24 |
'evaluate_safety_completeness',
|
| 25 |
'run_full_evaluation'
|
| 26 |
]
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from .evaluators import (
|
|
|
|
| 7 |
EvaluationResult,
|
| 8 |
+
GradedScore,
|
|
|
|
| 9 |
evaluate_actionability,
|
| 10 |
evaluate_clarity,
|
| 11 |
+
evaluate_clinical_accuracy,
|
| 12 |
+
evaluate_evidence_grounding,
|
| 13 |
evaluate_safety_completeness,
|
| 14 |
+
run_full_evaluation,
|
| 15 |
)
|
| 16 |
|
| 17 |
__all__ = [
|
|
|
|
| 18 |
'EvaluationResult',
|
| 19 |
+
'GradedScore',
|
|
|
|
| 20 |
'evaluate_actionability',
|
| 21 |
'evaluate_clarity',
|
| 22 |
+
'evaluate_clinical_accuracy',
|
| 23 |
+
'evaluate_evidence_grounding',
|
| 24 |
'evaluate_safety_completeness',
|
| 25 |
'run_full_evaluation'
|
| 26 |
]
|
src/evaluation/evaluators.py
CHANGED
|
@@ -22,11 +22,13 @@ Usage:
|
|
| 22 |
print(f"Average score: {result.average_score():.2f}")
|
| 23 |
"""
|
| 24 |
|
| 25 |
-
import os
|
| 26 |
-
from pydantic import BaseModel, Field
|
| 27 |
-
from typing import Dict, Any, List
|
| 28 |
import json
|
|
|
|
|
|
|
|
|
|
| 29 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
|
|
| 30 |
from src.llm_config import get_chat_model
|
| 31 |
|
| 32 |
# Set to True for deterministic evaluation (testing)
|
|
@@ -46,8 +48,8 @@ class EvaluationResult(BaseModel):
|
|
| 46 |
actionability: GradedScore
|
| 47 |
clarity: GradedScore
|
| 48 |
safety_completeness: GradedScore
|
| 49 |
-
|
| 50 |
-
def to_vector(self) ->
|
| 51 |
"""Extract scores as a vector for Pareto analysis"""
|
| 52 |
return [
|
| 53 |
self.clinical_accuracy.score,
|
|
@@ -56,7 +58,7 @@ class EvaluationResult(BaseModel):
|
|
| 56 |
self.clarity.score,
|
| 57 |
self.safety_completeness.score
|
| 58 |
]
|
| 59 |
-
|
| 60 |
def average_score(self) -> float:
|
| 61 |
"""Calculate average of all 5 dimensions"""
|
| 62 |
scores = self.to_vector()
|
|
@@ -65,7 +67,7 @@ class EvaluationResult(BaseModel):
|
|
| 65 |
|
| 66 |
# Evaluator 1: Clinical Accuracy (LLM-as-Judge)
|
| 67 |
def evaluate_clinical_accuracy(
|
| 68 |
-
final_response:
|
| 69 |
pubmed_context: str
|
| 70 |
) -> GradedScore:
|
| 71 |
"""
|
|
@@ -77,13 +79,13 @@ def evaluate_clinical_accuracy(
|
|
| 77 |
# Deterministic mode for testing
|
| 78 |
if DETERMINISTIC_MODE:
|
| 79 |
return _deterministic_clinical_accuracy(final_response, pubmed_context)
|
| 80 |
-
|
| 81 |
# Use cloud LLM for evaluation (FREE via Groq/Gemini)
|
| 82 |
evaluator_llm = get_chat_model(
|
| 83 |
temperature=0.0,
|
| 84 |
json_mode=True
|
| 85 |
)
|
| 86 |
-
|
| 87 |
prompt = ChatPromptTemplate.from_messages([
|
| 88 |
("system", """You are a medical expert evaluating clinical accuracy.
|
| 89 |
|
|
@@ -113,7 +115,7 @@ Respond ONLY with valid JSON in this format:
|
|
| 113 |
{context}
|
| 114 |
""")
|
| 115 |
])
|
| 116 |
-
|
| 117 |
chain = prompt | evaluator_llm
|
| 118 |
result = chain.invoke({
|
| 119 |
"patient_summary": final_response['patient_summary'],
|
|
@@ -121,7 +123,7 @@ Respond ONLY with valid JSON in this format:
|
|
| 121 |
"recommendations": final_response['clinical_recommendations'],
|
| 122 |
"context": pubmed_context
|
| 123 |
})
|
| 124 |
-
|
| 125 |
# Parse JSON response
|
| 126 |
try:
|
| 127 |
content = result.content if isinstance(result.content, str) else str(result.content)
|
|
@@ -134,7 +136,7 @@ Respond ONLY with valid JSON in this format:
|
|
| 134 |
|
| 135 |
# Evaluator 2: Evidence Grounding (Programmatic + LLM)
|
| 136 |
def evaluate_evidence_grounding(
|
| 137 |
-
final_response:
|
| 138 |
) -> GradedScore:
|
| 139 |
"""
|
| 140 |
Checks if all claims are backed by citations.
|
|
@@ -143,32 +145,32 @@ def evaluate_evidence_grounding(
|
|
| 143 |
# Count citations
|
| 144 |
pdf_refs = final_response['prediction_explanation'].get('pdf_references', [])
|
| 145 |
citation_count = len(pdf_refs)
|
| 146 |
-
|
| 147 |
# Check key drivers have evidence
|
| 148 |
key_drivers = final_response['prediction_explanation'].get('key_drivers', [])
|
| 149 |
drivers_with_evidence = sum(1 for d in key_drivers if d.get('evidence'))
|
| 150 |
-
|
| 151 |
# Citation coverage score
|
| 152 |
if len(key_drivers) > 0:
|
| 153 |
coverage = drivers_with_evidence / len(key_drivers)
|
| 154 |
else:
|
| 155 |
coverage = 0.0
|
| 156 |
-
|
| 157 |
# Base score from programmatic checks
|
| 158 |
base_score = min(1.0, citation_count / 5.0) * 0.5 + coverage * 0.5
|
| 159 |
-
|
| 160 |
reasoning = f"""
|
| 161 |
Citations found: {citation_count}
|
| 162 |
Key drivers with evidence: {drivers_with_evidence}/{len(key_drivers)}
|
| 163 |
Citation coverage: {coverage:.1%}
|
| 164 |
"""
|
| 165 |
-
|
| 166 |
return GradedScore(score=base_score, reasoning=reasoning.strip())
|
| 167 |
|
| 168 |
|
| 169 |
# Evaluator 3: Clinical Actionability (LLM-as-Judge)
|
| 170 |
def evaluate_actionability(
|
| 171 |
-
final_response:
|
| 172 |
) -> GradedScore:
|
| 173 |
"""
|
| 174 |
Evaluates if recommendations are actionable and safe.
|
|
@@ -179,13 +181,13 @@ def evaluate_actionability(
|
|
| 179 |
# Deterministic mode for testing
|
| 180 |
if DETERMINISTIC_MODE:
|
| 181 |
return _deterministic_actionability(final_response)
|
| 182 |
-
|
| 183 |
# Use cloud LLM for evaluation (FREE via Groq/Gemini)
|
| 184 |
evaluator_llm = get_chat_model(
|
| 185 |
temperature=0.0,
|
| 186 |
json_mode=True
|
| 187 |
)
|
| 188 |
-
|
| 189 |
prompt = ChatPromptTemplate.from_messages([
|
| 190 |
("system", """You are a clinical care coordinator evaluating actionability.
|
| 191 |
|
|
@@ -216,7 +218,7 @@ Respond ONLY with valid JSON in this format:
|
|
| 216 |
{confidence}
|
| 217 |
""")
|
| 218 |
])
|
| 219 |
-
|
| 220 |
chain = prompt | evaluator_llm
|
| 221 |
recs = final_response['clinical_recommendations']
|
| 222 |
result = chain.invoke({
|
|
@@ -225,7 +227,7 @@ Respond ONLY with valid JSON in this format:
|
|
| 225 |
"monitoring": recs.get('monitoring', []),
|
| 226 |
"confidence": final_response['confidence_assessment']
|
| 227 |
})
|
| 228 |
-
|
| 229 |
# Parse JSON response
|
| 230 |
try:
|
| 231 |
parsed = json.loads(result.content if isinstance(result.content, str) else str(result.content))
|
|
@@ -237,7 +239,7 @@ Respond ONLY with valid JSON in this format:
|
|
| 237 |
|
| 238 |
# Evaluator 4: Explainability Clarity (Programmatic)
|
| 239 |
def evaluate_clarity(
|
| 240 |
-
final_response:
|
| 241 |
) -> GradedScore:
|
| 242 |
"""
|
| 243 |
Measures readability and patient-friendliness.
|
|
@@ -248,16 +250,16 @@ def evaluate_clarity(
|
|
| 248 |
# Deterministic mode for testing
|
| 249 |
if DETERMINISTIC_MODE:
|
| 250 |
return _deterministic_clarity(final_response)
|
| 251 |
-
|
| 252 |
try:
|
| 253 |
import textstat
|
| 254 |
has_textstat = True
|
| 255 |
except ImportError:
|
| 256 |
has_textstat = False
|
| 257 |
-
|
| 258 |
# Get patient narrative
|
| 259 |
narrative = final_response['patient_summary'].get('narrative', '')
|
| 260 |
-
|
| 261 |
if has_textstat:
|
| 262 |
# Calculate readability (Flesch Reading Ease)
|
| 263 |
# Score 60-70 = Standard (8th-9th grade)
|
|
@@ -275,24 +277,24 @@ def evaluate_clarity(
|
|
| 275 |
readability_score = 0.9
|
| 276 |
else:
|
| 277 |
readability_score = max(0.5, 1.0 - (avg_words - 20) * 0.02)
|
| 278 |
-
|
| 279 |
# Medical jargon detection (simple heuristic)
|
| 280 |
medical_terms = [
|
| 281 |
'pathophysiology', 'etiology', 'hemostasis', 'coagulation',
|
| 282 |
'thrombocytopenia', 'erythropoiesis', 'gluconeogenesis'
|
| 283 |
]
|
| 284 |
jargon_count = sum(1 for term in medical_terms if term.lower() in narrative.lower())
|
| 285 |
-
|
| 286 |
# Length check (too short = vague, too long = overwhelming)
|
| 287 |
word_count = len(narrative.split())
|
| 288 |
optimal_length = 50 <= word_count <= 150
|
| 289 |
-
|
| 290 |
# Scoring
|
| 291 |
jargon_penalty = max(0.0, 1.0 - (jargon_count * 0.2))
|
| 292 |
length_score = 1.0 if optimal_length else 0.7
|
| 293 |
-
|
| 294 |
final_score = (readability_score * 0.5 + jargon_penalty * 0.3 + length_score * 0.2)
|
| 295 |
-
|
| 296 |
if has_textstat:
|
| 297 |
reasoning = f"""
|
| 298 |
Flesch Reading Ease: {flesch_score:.1f} (Target: 60-70)
|
|
@@ -307,63 +309,63 @@ def evaluate_clarity(
|
|
| 307 |
Word count: {word_count} (Optimal: 50-150)
|
| 308 |
Note: textstat not available, using fallback metrics
|
| 309 |
"""
|
| 310 |
-
|
| 311 |
return GradedScore(score=final_score, reasoning=reasoning.strip())
|
| 312 |
|
| 313 |
|
| 314 |
# Evaluator 5: Safety & Completeness (Programmatic)
|
| 315 |
def evaluate_safety_completeness(
|
| 316 |
-
final_response:
|
| 317 |
-
biomarkers:
|
| 318 |
) -> GradedScore:
|
| 319 |
"""
|
| 320 |
Checks if all safety concerns are flagged.
|
| 321 |
Programmatic validation.
|
| 322 |
"""
|
| 323 |
from src.biomarker_validator import BiomarkerValidator
|
| 324 |
-
|
| 325 |
# Initialize validator
|
| 326 |
validator = BiomarkerValidator()
|
| 327 |
-
|
| 328 |
# Count out-of-range biomarkers
|
| 329 |
out_of_range_count = 0
|
| 330 |
critical_count = 0
|
| 331 |
-
|
| 332 |
for name, value in biomarkers.items():
|
| 333 |
result = validator.validate_biomarker(name, value) # Fixed: use validate_biomarker instead of validate_single
|
| 334 |
if result.status in ['HIGH', 'LOW', 'CRITICAL_HIGH', 'CRITICAL_LOW']:
|
| 335 |
out_of_range_count += 1
|
| 336 |
if result.status in ['CRITICAL_HIGH', 'CRITICAL_LOW']:
|
| 337 |
critical_count += 1
|
| 338 |
-
|
| 339 |
# Count safety alerts in output
|
| 340 |
safety_alerts = final_response.get('safety_alerts', [])
|
| 341 |
alert_count = len(safety_alerts)
|
| 342 |
critical_alerts = sum(1 for a in safety_alerts if a.get('severity') == 'CRITICAL')
|
| 343 |
-
|
| 344 |
# Check if all critical values have alerts
|
| 345 |
critical_coverage = critical_alerts / critical_count if critical_count > 0 else 1.0
|
| 346 |
-
|
| 347 |
# Check for disclaimer
|
| 348 |
has_disclaimer = 'disclaimer' in final_response.get('metadata', {})
|
| 349 |
-
|
| 350 |
# Check for uncertainty acknowledgment
|
| 351 |
limitations = final_response['confidence_assessment'].get('limitations', [])
|
| 352 |
acknowledges_uncertainty = len(limitations) > 0
|
| 353 |
-
|
| 354 |
# Scoring
|
| 355 |
alert_score = min(1.0, alert_count / max(1, out_of_range_count))
|
| 356 |
critical_score = min(1.0, critical_coverage)
|
| 357 |
disclaimer_score = 1.0 if has_disclaimer else 0.0
|
| 358 |
uncertainty_score = 1.0 if acknowledges_uncertainty else 0.5
|
| 359 |
-
|
| 360 |
final_score = min(1.0, (
|
| 361 |
alert_score * 0.4 +
|
| 362 |
critical_score * 0.3 +
|
| 363 |
disclaimer_score * 0.2 +
|
| 364 |
uncertainty_score * 0.1
|
| 365 |
))
|
| 366 |
-
|
| 367 |
reasoning = f"""
|
| 368 |
Out-of-range biomarkers: {out_of_range_count}
|
| 369 |
Critical values: {critical_count}
|
|
@@ -373,15 +375,15 @@ def evaluate_safety_completeness(
|
|
| 373 |
Has disclaimer: {has_disclaimer}
|
| 374 |
Acknowledges uncertainty: {acknowledges_uncertainty}
|
| 375 |
"""
|
| 376 |
-
|
| 377 |
return GradedScore(score=final_score, reasoning=reasoning.strip())
|
| 378 |
|
| 379 |
|
| 380 |
# Master Evaluation Function
|
| 381 |
def run_full_evaluation(
|
| 382 |
-
final_response:
|
| 383 |
-
agent_outputs:
|
| 384 |
-
biomarkers:
|
| 385 |
) -> EvaluationResult:
|
| 386 |
"""
|
| 387 |
Orchestrates all 5 evaluators and returns complete assessment.
|
|
@@ -389,7 +391,7 @@ def run_full_evaluation(
|
|
| 389 |
print("=" * 70)
|
| 390 |
print("RUNNING 5D EVALUATION GAUNTLET")
|
| 391 |
print("=" * 70)
|
| 392 |
-
|
| 393 |
# Extract context from agent outputs
|
| 394 |
pubmed_context = ""
|
| 395 |
for output in agent_outputs:
|
|
@@ -402,27 +404,27 @@ def run_full_evaluation(
|
|
| 402 |
else:
|
| 403 |
pubmed_context = str(findings)
|
| 404 |
break
|
| 405 |
-
|
| 406 |
# Run all evaluators
|
| 407 |
print("\n1. Evaluating Clinical Accuracy...")
|
| 408 |
clinical_accuracy = evaluate_clinical_accuracy(final_response, pubmed_context)
|
| 409 |
-
|
| 410 |
print("2. Evaluating Evidence Grounding...")
|
| 411 |
evidence_grounding = evaluate_evidence_grounding(final_response)
|
| 412 |
-
|
| 413 |
print("3. Evaluating Clinical Actionability...")
|
| 414 |
actionability = evaluate_actionability(final_response)
|
| 415 |
-
|
| 416 |
print("4. Evaluating Explainability Clarity...")
|
| 417 |
clarity = evaluate_clarity(final_response)
|
| 418 |
-
|
| 419 |
print("5. Evaluating Safety & Completeness...")
|
| 420 |
safety_completeness = evaluate_safety_completeness(final_response, biomarkers)
|
| 421 |
-
|
| 422 |
print("\n" + "=" * 70)
|
| 423 |
print("EVALUATION COMPLETE")
|
| 424 |
print("=" * 70)
|
| 425 |
-
|
| 426 |
return EvaluationResult(
|
| 427 |
clinical_accuracy=clinical_accuracy,
|
| 428 |
evidence_grounding=evidence_grounding,
|
|
@@ -437,26 +439,26 @@ def run_full_evaluation(
|
|
| 437 |
# ---------------------------------------------------------------------------
|
| 438 |
|
| 439 |
def _deterministic_clinical_accuracy(
|
| 440 |
-
final_response:
|
| 441 |
pubmed_context: str
|
| 442 |
) -> GradedScore:
|
| 443 |
"""Heuristic-based clinical accuracy (deterministic)."""
|
| 444 |
score = 0.5
|
| 445 |
reasons = []
|
| 446 |
-
|
| 447 |
# Check if response has expected structure
|
| 448 |
if final_response.get('patient_summary'):
|
| 449 |
score += 0.1
|
| 450 |
reasons.append("Has patient summary")
|
| 451 |
-
|
| 452 |
if final_response.get('prediction_explanation'):
|
| 453 |
score += 0.1
|
| 454 |
reasons.append("Has prediction explanation")
|
| 455 |
-
|
| 456 |
if final_response.get('clinical_recommendations'):
|
| 457 |
score += 0.1
|
| 458 |
reasons.append("Has clinical recommendations")
|
| 459 |
-
|
| 460 |
# Check for citations
|
| 461 |
pred = final_response.get('prediction_explanation', {})
|
| 462 |
if isinstance(pred, dict):
|
|
@@ -464,7 +466,7 @@ def _deterministic_clinical_accuracy(
|
|
| 464 |
if refs:
|
| 465 |
score += min(0.2, len(refs) * 0.05)
|
| 466 |
reasons.append(f"Has {len(refs)} citations")
|
| 467 |
-
|
| 468 |
return GradedScore(
|
| 469 |
score=min(1.0, score),
|
| 470 |
reasoning="[DETERMINISTIC] " + "; ".join(reasons)
|
|
@@ -472,12 +474,12 @@ def _deterministic_clinical_accuracy(
|
|
| 472 |
|
| 473 |
|
| 474 |
def _deterministic_actionability(
|
| 475 |
-
final_response:
|
| 476 |
) -> GradedScore:
|
| 477 |
"""Heuristic-based actionability (deterministic)."""
|
| 478 |
score = 0.5
|
| 479 |
reasons = []
|
| 480 |
-
|
| 481 |
recs = final_response.get('clinical_recommendations', {})
|
| 482 |
if isinstance(recs, dict):
|
| 483 |
if recs.get('immediate_actions'):
|
|
@@ -489,7 +491,7 @@ def _deterministic_actionability(
|
|
| 489 |
if recs.get('monitoring'):
|
| 490 |
score += 0.1
|
| 491 |
reasons.append("Has monitoring recommendations")
|
| 492 |
-
|
| 493 |
return GradedScore(
|
| 494 |
score=min(1.0, score),
|
| 495 |
reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Missing recommendations"
|
|
@@ -497,12 +499,12 @@ def _deterministic_actionability(
|
|
| 497 |
|
| 498 |
|
| 499 |
def _deterministic_clarity(
|
| 500 |
-
final_response:
|
| 501 |
) -> GradedScore:
|
| 502 |
"""Heuristic-based clarity (deterministic)."""
|
| 503 |
score = 0.5
|
| 504 |
reasons = []
|
| 505 |
-
|
| 506 |
summary = final_response.get('patient_summary', '')
|
| 507 |
if isinstance(summary, str):
|
| 508 |
word_count = len(summary.split())
|
|
@@ -512,16 +514,16 @@ def _deterministic_clarity(
|
|
| 512 |
elif word_count > 0:
|
| 513 |
score += 0.1
|
| 514 |
reasons.append("Has summary")
|
| 515 |
-
|
| 516 |
# Check for structured output
|
| 517 |
if final_response.get('biomarker_flags'):
|
| 518 |
score += 0.15
|
| 519 |
reasons.append("Has biomarker flags")
|
| 520 |
-
|
| 521 |
if final_response.get('key_findings'):
|
| 522 |
score += 0.15
|
| 523 |
reasons.append("Has key findings")
|
| 524 |
-
|
| 525 |
return GradedScore(
|
| 526 |
score=min(1.0, score),
|
| 527 |
reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Limited structure"
|
|
|
|
| 22 |
print(f"Average score: {result.average_score():.2f}")
|
| 23 |
"""
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
import json
|
| 26 |
+
import os
|
| 27 |
+
from typing import Any
|
| 28 |
+
|
| 29 |
from langchain_core.prompts import ChatPromptTemplate
|
| 30 |
+
from pydantic import BaseModel, Field
|
| 31 |
+
|
| 32 |
from src.llm_config import get_chat_model
|
| 33 |
|
| 34 |
# Set to True for deterministic evaluation (testing)
|
|
|
|
| 48 |
actionability: GradedScore
|
| 49 |
clarity: GradedScore
|
| 50 |
safety_completeness: GradedScore
|
| 51 |
+
|
| 52 |
+
def to_vector(self) -> list[float]:
|
| 53 |
"""Extract scores as a vector for Pareto analysis"""
|
| 54 |
return [
|
| 55 |
self.clinical_accuracy.score,
|
|
|
|
| 58 |
self.clarity.score,
|
| 59 |
self.safety_completeness.score
|
| 60 |
]
|
| 61 |
+
|
| 62 |
def average_score(self) -> float:
|
| 63 |
"""Calculate average of all 5 dimensions"""
|
| 64 |
scores = self.to_vector()
|
|
|
|
| 67 |
|
| 68 |
# Evaluator 1: Clinical Accuracy (LLM-as-Judge)
|
| 69 |
def evaluate_clinical_accuracy(
|
| 70 |
+
final_response: dict[str, Any],
|
| 71 |
pubmed_context: str
|
| 72 |
) -> GradedScore:
|
| 73 |
"""
|
|
|
|
| 79 |
# Deterministic mode for testing
|
| 80 |
if DETERMINISTIC_MODE:
|
| 81 |
return _deterministic_clinical_accuracy(final_response, pubmed_context)
|
| 82 |
+
|
| 83 |
# Use cloud LLM for evaluation (FREE via Groq/Gemini)
|
| 84 |
evaluator_llm = get_chat_model(
|
| 85 |
temperature=0.0,
|
| 86 |
json_mode=True
|
| 87 |
)
|
| 88 |
+
|
| 89 |
prompt = ChatPromptTemplate.from_messages([
|
| 90 |
("system", """You are a medical expert evaluating clinical accuracy.
|
| 91 |
|
|
|
|
| 115 |
{context}
|
| 116 |
""")
|
| 117 |
])
|
| 118 |
+
|
| 119 |
chain = prompt | evaluator_llm
|
| 120 |
result = chain.invoke({
|
| 121 |
"patient_summary": final_response['patient_summary'],
|
|
|
|
| 123 |
"recommendations": final_response['clinical_recommendations'],
|
| 124 |
"context": pubmed_context
|
| 125 |
})
|
| 126 |
+
|
| 127 |
# Parse JSON response
|
| 128 |
try:
|
| 129 |
content = result.content if isinstance(result.content, str) else str(result.content)
|
|
|
|
| 136 |
|
| 137 |
# Evaluator 2: Evidence Grounding (Programmatic + LLM)
|
| 138 |
def evaluate_evidence_grounding(
|
| 139 |
+
final_response: dict[str, Any]
|
| 140 |
) -> GradedScore:
|
| 141 |
"""
|
| 142 |
Checks if all claims are backed by citations.
|
|
|
|
| 145 |
# Count citations
|
| 146 |
pdf_refs = final_response['prediction_explanation'].get('pdf_references', [])
|
| 147 |
citation_count = len(pdf_refs)
|
| 148 |
+
|
| 149 |
# Check key drivers have evidence
|
| 150 |
key_drivers = final_response['prediction_explanation'].get('key_drivers', [])
|
| 151 |
drivers_with_evidence = sum(1 for d in key_drivers if d.get('evidence'))
|
| 152 |
+
|
| 153 |
# Citation coverage score
|
| 154 |
if len(key_drivers) > 0:
|
| 155 |
coverage = drivers_with_evidence / len(key_drivers)
|
| 156 |
else:
|
| 157 |
coverage = 0.0
|
| 158 |
+
|
| 159 |
# Base score from programmatic checks
|
| 160 |
base_score = min(1.0, citation_count / 5.0) * 0.5 + coverage * 0.5
|
| 161 |
+
|
| 162 |
reasoning = f"""
|
| 163 |
Citations found: {citation_count}
|
| 164 |
Key drivers with evidence: {drivers_with_evidence}/{len(key_drivers)}
|
| 165 |
Citation coverage: {coverage:.1%}
|
| 166 |
"""
|
| 167 |
+
|
| 168 |
return GradedScore(score=base_score, reasoning=reasoning.strip())
|
| 169 |
|
| 170 |
|
| 171 |
# Evaluator 3: Clinical Actionability (LLM-as-Judge)
|
| 172 |
def evaluate_actionability(
|
| 173 |
+
final_response: dict[str, Any]
|
| 174 |
) -> GradedScore:
|
| 175 |
"""
|
| 176 |
Evaluates if recommendations are actionable and safe.
|
|
|
|
| 181 |
# Deterministic mode for testing
|
| 182 |
if DETERMINISTIC_MODE:
|
| 183 |
return _deterministic_actionability(final_response)
|
| 184 |
+
|
| 185 |
# Use cloud LLM for evaluation (FREE via Groq/Gemini)
|
| 186 |
evaluator_llm = get_chat_model(
|
| 187 |
temperature=0.0,
|
| 188 |
json_mode=True
|
| 189 |
)
|
| 190 |
+
|
| 191 |
prompt = ChatPromptTemplate.from_messages([
|
| 192 |
("system", """You are a clinical care coordinator evaluating actionability.
|
| 193 |
|
|
|
|
| 218 |
{confidence}
|
| 219 |
""")
|
| 220 |
])
|
| 221 |
+
|
| 222 |
chain = prompt | evaluator_llm
|
| 223 |
recs = final_response['clinical_recommendations']
|
| 224 |
result = chain.invoke({
|
|
|
|
| 227 |
"monitoring": recs.get('monitoring', []),
|
| 228 |
"confidence": final_response['confidence_assessment']
|
| 229 |
})
|
| 230 |
+
|
| 231 |
# Parse JSON response
|
| 232 |
try:
|
| 233 |
parsed = json.loads(result.content if isinstance(result.content, str) else str(result.content))
|
|
|
|
| 239 |
|
| 240 |
# Evaluator 4: Explainability Clarity (Programmatic)
|
| 241 |
def evaluate_clarity(
|
| 242 |
+
final_response: dict[str, Any]
|
| 243 |
) -> GradedScore:
|
| 244 |
"""
|
| 245 |
Measures readability and patient-friendliness.
|
|
|
|
| 250 |
# Deterministic mode for testing
|
| 251 |
if DETERMINISTIC_MODE:
|
| 252 |
return _deterministic_clarity(final_response)
|
| 253 |
+
|
| 254 |
try:
|
| 255 |
import textstat
|
| 256 |
has_textstat = True
|
| 257 |
except ImportError:
|
| 258 |
has_textstat = False
|
| 259 |
+
|
| 260 |
# Get patient narrative
|
| 261 |
narrative = final_response['patient_summary'].get('narrative', '')
|
| 262 |
+
|
| 263 |
if has_textstat:
|
| 264 |
# Calculate readability (Flesch Reading Ease)
|
| 265 |
# Score 60-70 = Standard (8th-9th grade)
|
|
|
|
| 277 |
readability_score = 0.9
|
| 278 |
else:
|
| 279 |
readability_score = max(0.5, 1.0 - (avg_words - 20) * 0.02)
|
| 280 |
+
|
| 281 |
# Medical jargon detection (simple heuristic)
|
| 282 |
medical_terms = [
|
| 283 |
'pathophysiology', 'etiology', 'hemostasis', 'coagulation',
|
| 284 |
'thrombocytopenia', 'erythropoiesis', 'gluconeogenesis'
|
| 285 |
]
|
| 286 |
jargon_count = sum(1 for term in medical_terms if term.lower() in narrative.lower())
|
| 287 |
+
|
| 288 |
# Length check (too short = vague, too long = overwhelming)
|
| 289 |
word_count = len(narrative.split())
|
| 290 |
optimal_length = 50 <= word_count <= 150
|
| 291 |
+
|
| 292 |
# Scoring
|
| 293 |
jargon_penalty = max(0.0, 1.0 - (jargon_count * 0.2))
|
| 294 |
length_score = 1.0 if optimal_length else 0.7
|
| 295 |
+
|
| 296 |
final_score = (readability_score * 0.5 + jargon_penalty * 0.3 + length_score * 0.2)
|
| 297 |
+
|
| 298 |
if has_textstat:
|
| 299 |
reasoning = f"""
|
| 300 |
Flesch Reading Ease: {flesch_score:.1f} (Target: 60-70)
|
|
|
|
| 309 |
Word count: {word_count} (Optimal: 50-150)
|
| 310 |
Note: textstat not available, using fallback metrics
|
| 311 |
"""
|
| 312 |
+
|
| 313 |
return GradedScore(score=final_score, reasoning=reasoning.strip())
|
| 314 |
|
| 315 |
|
| 316 |
# Evaluator 5: Safety & Completeness (Programmatic)
|
| 317 |
def evaluate_safety_completeness(
|
| 318 |
+
final_response: dict[str, Any],
|
| 319 |
+
biomarkers: dict[str, float]
|
| 320 |
) -> GradedScore:
|
| 321 |
"""
|
| 322 |
Checks if all safety concerns are flagged.
|
| 323 |
Programmatic validation.
|
| 324 |
"""
|
| 325 |
from src.biomarker_validator import BiomarkerValidator
|
| 326 |
+
|
| 327 |
# Initialize validator
|
| 328 |
validator = BiomarkerValidator()
|
| 329 |
+
|
| 330 |
# Count out-of-range biomarkers
|
| 331 |
out_of_range_count = 0
|
| 332 |
critical_count = 0
|
| 333 |
+
|
| 334 |
for name, value in biomarkers.items():
|
| 335 |
result = validator.validate_biomarker(name, value) # Fixed: use validate_biomarker instead of validate_single
|
| 336 |
if result.status in ['HIGH', 'LOW', 'CRITICAL_HIGH', 'CRITICAL_LOW']:
|
| 337 |
out_of_range_count += 1
|
| 338 |
if result.status in ['CRITICAL_HIGH', 'CRITICAL_LOW']:
|
| 339 |
critical_count += 1
|
| 340 |
+
|
| 341 |
# Count safety alerts in output
|
| 342 |
safety_alerts = final_response.get('safety_alerts', [])
|
| 343 |
alert_count = len(safety_alerts)
|
| 344 |
critical_alerts = sum(1 for a in safety_alerts if a.get('severity') == 'CRITICAL')
|
| 345 |
+
|
| 346 |
# Check if all critical values have alerts
|
| 347 |
critical_coverage = critical_alerts / critical_count if critical_count > 0 else 1.0
|
| 348 |
+
|
| 349 |
# Check for disclaimer
|
| 350 |
has_disclaimer = 'disclaimer' in final_response.get('metadata', {})
|
| 351 |
+
|
| 352 |
# Check for uncertainty acknowledgment
|
| 353 |
limitations = final_response['confidence_assessment'].get('limitations', [])
|
| 354 |
acknowledges_uncertainty = len(limitations) > 0
|
| 355 |
+
|
| 356 |
# Scoring
|
| 357 |
alert_score = min(1.0, alert_count / max(1, out_of_range_count))
|
| 358 |
critical_score = min(1.0, critical_coverage)
|
| 359 |
disclaimer_score = 1.0 if has_disclaimer else 0.0
|
| 360 |
uncertainty_score = 1.0 if acknowledges_uncertainty else 0.5
|
| 361 |
+
|
| 362 |
final_score = min(1.0, (
|
| 363 |
alert_score * 0.4 +
|
| 364 |
critical_score * 0.3 +
|
| 365 |
disclaimer_score * 0.2 +
|
| 366 |
uncertainty_score * 0.1
|
| 367 |
))
|
| 368 |
+
|
| 369 |
reasoning = f"""
|
| 370 |
Out-of-range biomarkers: {out_of_range_count}
|
| 371 |
Critical values: {critical_count}
|
|
|
|
| 375 |
Has disclaimer: {has_disclaimer}
|
| 376 |
Acknowledges uncertainty: {acknowledges_uncertainty}
|
| 377 |
"""
|
| 378 |
+
|
| 379 |
return GradedScore(score=final_score, reasoning=reasoning.strip())
|
| 380 |
|
| 381 |
|
| 382 |
# Master Evaluation Function
|
| 383 |
def run_full_evaluation(
|
| 384 |
+
final_response: dict[str, Any],
|
| 385 |
+
agent_outputs: list[Any],
|
| 386 |
+
biomarkers: dict[str, float]
|
| 387 |
) -> EvaluationResult:
|
| 388 |
"""
|
| 389 |
Orchestrates all 5 evaluators and returns complete assessment.
|
|
|
|
| 391 |
print("=" * 70)
|
| 392 |
print("RUNNING 5D EVALUATION GAUNTLET")
|
| 393 |
print("=" * 70)
|
| 394 |
+
|
| 395 |
# Extract context from agent outputs
|
| 396 |
pubmed_context = ""
|
| 397 |
for output in agent_outputs:
|
|
|
|
| 404 |
else:
|
| 405 |
pubmed_context = str(findings)
|
| 406 |
break
|
| 407 |
+
|
| 408 |
# Run all evaluators
|
| 409 |
print("\n1. Evaluating Clinical Accuracy...")
|
| 410 |
clinical_accuracy = evaluate_clinical_accuracy(final_response, pubmed_context)
|
| 411 |
+
|
| 412 |
print("2. Evaluating Evidence Grounding...")
|
| 413 |
evidence_grounding = evaluate_evidence_grounding(final_response)
|
| 414 |
+
|
| 415 |
print("3. Evaluating Clinical Actionability...")
|
| 416 |
actionability = evaluate_actionability(final_response)
|
| 417 |
+
|
| 418 |
print("4. Evaluating Explainability Clarity...")
|
| 419 |
clarity = evaluate_clarity(final_response)
|
| 420 |
+
|
| 421 |
print("5. Evaluating Safety & Completeness...")
|
| 422 |
safety_completeness = evaluate_safety_completeness(final_response, biomarkers)
|
| 423 |
+
|
| 424 |
print("\n" + "=" * 70)
|
| 425 |
print("EVALUATION COMPLETE")
|
| 426 |
print("=" * 70)
|
| 427 |
+
|
| 428 |
return EvaluationResult(
|
| 429 |
clinical_accuracy=clinical_accuracy,
|
| 430 |
evidence_grounding=evidence_grounding,
|
|
|
|
| 439 |
# ---------------------------------------------------------------------------
|
| 440 |
|
| 441 |
def _deterministic_clinical_accuracy(
|
| 442 |
+
final_response: dict[str, Any],
|
| 443 |
pubmed_context: str
|
| 444 |
) -> GradedScore:
|
| 445 |
"""Heuristic-based clinical accuracy (deterministic)."""
|
| 446 |
score = 0.5
|
| 447 |
reasons = []
|
| 448 |
+
|
| 449 |
# Check if response has expected structure
|
| 450 |
if final_response.get('patient_summary'):
|
| 451 |
score += 0.1
|
| 452 |
reasons.append("Has patient summary")
|
| 453 |
+
|
| 454 |
if final_response.get('prediction_explanation'):
|
| 455 |
score += 0.1
|
| 456 |
reasons.append("Has prediction explanation")
|
| 457 |
+
|
| 458 |
if final_response.get('clinical_recommendations'):
|
| 459 |
score += 0.1
|
| 460 |
reasons.append("Has clinical recommendations")
|
| 461 |
+
|
| 462 |
# Check for citations
|
| 463 |
pred = final_response.get('prediction_explanation', {})
|
| 464 |
if isinstance(pred, dict):
|
|
|
|
| 466 |
if refs:
|
| 467 |
score += min(0.2, len(refs) * 0.05)
|
| 468 |
reasons.append(f"Has {len(refs)} citations")
|
| 469 |
+
|
| 470 |
return GradedScore(
|
| 471 |
score=min(1.0, score),
|
| 472 |
reasoning="[DETERMINISTIC] " + "; ".join(reasons)
|
|
|
|
| 474 |
|
| 475 |
|
| 476 |
def _deterministic_actionability(
|
| 477 |
+
final_response: dict[str, Any]
|
| 478 |
) -> GradedScore:
|
| 479 |
"""Heuristic-based actionability (deterministic)."""
|
| 480 |
score = 0.5
|
| 481 |
reasons = []
|
| 482 |
+
|
| 483 |
recs = final_response.get('clinical_recommendations', {})
|
| 484 |
if isinstance(recs, dict):
|
| 485 |
if recs.get('immediate_actions'):
|
|
|
|
| 491 |
if recs.get('monitoring'):
|
| 492 |
score += 0.1
|
| 493 |
reasons.append("Has monitoring recommendations")
|
| 494 |
+
|
| 495 |
return GradedScore(
|
| 496 |
score=min(1.0, score),
|
| 497 |
reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Missing recommendations"
|
|
|
|
| 499 |
|
| 500 |
|
| 501 |
def _deterministic_clarity(
|
| 502 |
+
final_response: dict[str, Any]
|
| 503 |
) -> GradedScore:
|
| 504 |
"""Heuristic-based clarity (deterministic)."""
|
| 505 |
score = 0.5
|
| 506 |
reasons = []
|
| 507 |
+
|
| 508 |
summary = final_response.get('patient_summary', '')
|
| 509 |
if isinstance(summary, str):
|
| 510 |
word_count = len(summary.split())
|
|
|
|
| 514 |
elif word_count > 0:
|
| 515 |
score += 0.1
|
| 516 |
reasons.append("Has summary")
|
| 517 |
+
|
| 518 |
# Check for structured output
|
| 519 |
if final_response.get('biomarker_flags'):
|
| 520 |
score += 0.15
|
| 521 |
reasons.append("Has biomarker flags")
|
| 522 |
+
|
| 523 |
if final_response.get('key_findings'):
|
| 524 |
score += 0.15
|
| 525 |
reasons.append("Has key findings")
|
| 526 |
+
|
| 527 |
return GradedScore(
|
| 528 |
score=min(1.0, score),
|
| 529 |
reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Limited structure"
|
src/exceptions.py
CHANGED
|
@@ -6,15 +6,14 @@ Each service layer raises its own exception type so callers can handle
|
|
| 6 |
failures precisely without leaking implementation details.
|
| 7 |
"""
|
| 8 |
|
| 9 |
-
from typing import Any
|
| 10 |
-
|
| 11 |
|
| 12 |
# ── Base ──────────────────────────────────────────────────────────────────────
|
| 13 |
|
| 14 |
class MediGuardError(Exception):
|
| 15 |
"""Root exception for the entire MediGuard AI application."""
|
| 16 |
|
| 17 |
-
def __init__(self, message: str = "", *, details:
|
| 18 |
self.details = details or {}
|
| 19 |
super().__init__(message)
|
| 20 |
|
|
|
|
| 6 |
failures precisely without leaking implementation details.
|
| 7 |
"""
|
| 8 |
|
| 9 |
+
from typing import Any
|
|
|
|
| 10 |
|
| 11 |
# ── Base ──────────────────────────────────────────────────────────────────────
|
| 12 |
|
| 13 |
class MediGuardError(Exception):
|
| 14 |
"""Root exception for the entire MediGuard AI application."""
|
| 15 |
|
| 16 |
+
def __init__(self, message: str = "", *, details: dict[str, Any] | None = None):
|
| 17 |
self.details = details or {}
|
| 18 |
super().__init__(message)
|
| 19 |
|
src/gradio_app.py
CHANGED
|
@@ -17,15 +17,33 @@ logger = logging.getLogger(__name__)
|
|
| 17 |
API_BASE = os.getenv("MEDIGUARD_API_URL", "http://localhost:8000")
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
"""Call the /ask endpoint."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
try:
|
| 23 |
-
with httpx.
|
| 24 |
-
resp = client.post(f"{API_BASE}/ask", json={"question": question})
|
| 25 |
resp.raise_for_status()
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
except Exception as exc:
|
| 28 |
-
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def _call_analyze(biomarkers_json: str) -> str:
|
|
@@ -47,7 +65,7 @@ def _call_analyze(biomarkers_json: str) -> str:
|
|
| 47 |
return f"Error: {exc}"
|
| 48 |
|
| 49 |
|
| 50 |
-
def launch_gradio(share: bool = False) -> None:
|
| 51 |
"""Launch the Gradio interface."""
|
| 52 |
try:
|
| 53 |
import gradio as gr
|
|
@@ -62,14 +80,27 @@ def launch_gradio(share: bool = False) -> None:
|
|
| 62 |
)
|
| 63 |
|
| 64 |
with gr.Tab("Ask a Question"):
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
with gr.Tab("Analyze Biomarkers"):
|
| 75 |
bio_input = gr.Textbox(
|
|
@@ -82,20 +113,28 @@ def launch_gradio(share: bool = False) -> None:
|
|
| 82 |
analyze_btn.click(fn=_call_analyze, inputs=bio_input, outputs=analysis_output)
|
| 83 |
|
| 84 |
with gr.Tab("Search Knowledge Base"):
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
search_btn = gr.Button("Search", variant="primary")
|
| 91 |
search_output = gr.Textbox(label="Results", lines=15, interactive=False)
|
| 92 |
|
| 93 |
-
def _call_search(query: str) -> str:
|
| 94 |
try:
|
| 95 |
with httpx.Client(timeout=30.0) as client:
|
| 96 |
resp = client.post(
|
| 97 |
f"{API_BASE}/search",
|
| 98 |
-
json={"query": query, "top_k": 5, "mode":
|
| 99 |
)
|
| 100 |
resp.raise_for_status()
|
| 101 |
data = resp.json()
|
|
@@ -112,10 +151,11 @@ def launch_gradio(share: bool = False) -> None:
|
|
| 112 |
except Exception as exc:
|
| 113 |
return f"Error: {exc}"
|
| 114 |
|
| 115 |
-
search_btn.click(fn=_call_search, inputs=search_input, outputs=search_output)
|
| 116 |
|
| 117 |
-
demo.launch(server_name="0.0.0.0", server_port=
|
| 118 |
|
| 119 |
|
| 120 |
if __name__ == "__main__":
|
| 121 |
-
|
|
|
|
|
|
| 17 |
API_BASE = os.getenv("MEDIGUARD_API_URL", "http://localhost:8000")
|
| 18 |
|
| 19 |
|
| 20 |
+
def ask_stream(question: str, history: list, model: str):
|
| 21 |
+
"""Call the /ask/stream endpoint."""
|
| 22 |
+
history = history or []
|
| 23 |
+
if not question.strip():
|
| 24 |
+
yield "", history
|
| 25 |
+
return
|
| 26 |
+
|
| 27 |
+
history.append((question, ""))
|
| 28 |
+
|
| 29 |
try:
|
| 30 |
+
with httpx.stream("POST", f"{API_BASE}/ask/stream", json={"question": question}, timeout=60.0) as resp:
|
|
|
|
| 31 |
resp.raise_for_status()
|
| 32 |
+
for line in resp.iter_lines():
|
| 33 |
+
if line.startswith("data: "):
|
| 34 |
+
content = line[6:]
|
| 35 |
+
if content == "[DONE]":
|
| 36 |
+
break
|
| 37 |
+
try:
|
| 38 |
+
data = json.loads(content)
|
| 39 |
+
current_bot_msg = history[-1][1] + data.get("text", "")
|
| 40 |
+
history[-1] = (question, current_bot_msg)
|
| 41 |
+
yield "", history
|
| 42 |
+
except Exception as trace_exc:
|
| 43 |
+
logger.debug("Failed to parse streaming chunk: %s", trace_exc)
|
| 44 |
except Exception as exc:
|
| 45 |
+
history[-1] = (question, f"Error: {exc}")
|
| 46 |
+
yield "", history
|
| 47 |
|
| 48 |
|
| 49 |
def _call_analyze(biomarkers_json: str) -> str:
|
|
|
|
| 65 |
return f"Error: {exc}"
|
| 66 |
|
| 67 |
|
| 68 |
+
def launch_gradio(share: bool = False, server_port: int = 7860) -> None:
|
| 69 |
"""Launch the Gradio interface."""
|
| 70 |
try:
|
| 71 |
import gradio as gr
|
|
|
|
| 80 |
)
|
| 81 |
|
| 82 |
with gr.Tab("Ask a Question"):
|
| 83 |
+
with gr.Row():
|
| 84 |
+
with gr.Column(scale=3):
|
| 85 |
+
chatbot = gr.Chatbot(label="Medical Q&A History", height=400)
|
| 86 |
+
question_input = gr.Textbox(
|
| 87 |
+
label="Medical Question",
|
| 88 |
+
placeholder="e.g., What does a high HbA1c level indicate?",
|
| 89 |
+
lines=2,
|
| 90 |
+
)
|
| 91 |
+
with gr.Row():
|
| 92 |
+
ask_btn = gr.Button("Ask (Streaming)", variant="primary")
|
| 93 |
+
clear_btn = gr.Button("Clear History")
|
| 94 |
+
|
| 95 |
+
with gr.Column(scale=1):
|
| 96 |
+
model_selector = gr.Dropdown(
|
| 97 |
+
choices=["llama-3.3-70b-versatile", "gemini-2.0-flash", "llama3.1:8b"],
|
| 98 |
+
value="llama-3.3-70b-versatile",
|
| 99 |
+
label="LLM Provider/Model"
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
ask_btn.click(fn=ask_stream, inputs=[question_input, chatbot, model_selector], outputs=[question_input, chatbot])
|
| 103 |
+
clear_btn.click(fn=lambda: ([], ""), outputs=[chatbot, question_input])
|
| 104 |
|
| 105 |
with gr.Tab("Analyze Biomarkers"):
|
| 106 |
bio_input = gr.Textbox(
|
|
|
|
| 113 |
analyze_btn.click(fn=_call_analyze, inputs=bio_input, outputs=analysis_output)
|
| 114 |
|
| 115 |
with gr.Tab("Search Knowledge Base"):
|
| 116 |
+
with gr.Row():
|
| 117 |
+
search_input = gr.Textbox(
|
| 118 |
+
label="Search Query",
|
| 119 |
+
placeholder="e.g., diabetes management guidelines",
|
| 120 |
+
lines=2,
|
| 121 |
+
scale=3
|
| 122 |
+
)
|
| 123 |
+
search_mode = gr.Radio(
|
| 124 |
+
choices=["hybrid", "bm25", "vector"],
|
| 125 |
+
value="hybrid",
|
| 126 |
+
label="Search Strategy",
|
| 127 |
+
scale=1
|
| 128 |
+
)
|
| 129 |
search_btn = gr.Button("Search", variant="primary")
|
| 130 |
search_output = gr.Textbox(label="Results", lines=15, interactive=False)
|
| 131 |
|
| 132 |
+
def _call_search(query: str, mode: str) -> str:
|
| 133 |
try:
|
| 134 |
with httpx.Client(timeout=30.0) as client:
|
| 135 |
resp = client.post(
|
| 136 |
f"{API_BASE}/search",
|
| 137 |
+
json={"query": query, "top_k": 5, "mode": mode},
|
| 138 |
)
|
| 139 |
resp.raise_for_status()
|
| 140 |
data = resp.json()
|
|
|
|
| 151 |
except Exception as exc:
|
| 152 |
return f"Error: {exc}"
|
| 153 |
|
| 154 |
+
search_btn.click(fn=_call_search, inputs=[search_input, search_mode], outputs=search_output)
|
| 155 |
|
| 156 |
+
demo.launch(server_name="0.0.0.0", server_port=server_port, share=share)
|
| 157 |
|
| 158 |
|
| 159 |
if __name__ == "__main__":
|
| 160 |
+
port = int(os.environ.get("GRADIO_PORT", 7860))
|
| 161 |
+
launch_gradio(server_port=port)
|
src/llm_config.py
CHANGED
|
@@ -14,7 +14,8 @@ Environment Variables (supports both naming conventions):
|
|
| 14 |
|
| 15 |
import os
|
| 16 |
import threading
|
| 17 |
-
from typing import Literal
|
|
|
|
| 18 |
from dotenv import load_dotenv
|
| 19 |
|
| 20 |
# Load environment variables
|
|
@@ -64,8 +65,8 @@ DEFAULT_LLM_PROVIDER = get_default_llm_provider()
|
|
| 64 |
|
| 65 |
|
| 66 |
def get_chat_model(
|
| 67 |
-
provider:
|
| 68 |
-
model:
|
| 69 |
temperature: float = 0.0,
|
| 70 |
json_mode: bool = False
|
| 71 |
):
|
|
@@ -83,61 +84,61 @@ def get_chat_model(
|
|
| 83 |
"""
|
| 84 |
# Use dynamic lookup to get current provider from environment
|
| 85 |
provider = provider or get_default_llm_provider()
|
| 86 |
-
|
| 87 |
if provider == "groq":
|
| 88 |
from langchain_groq import ChatGroq
|
| 89 |
-
|
| 90 |
api_key = get_groq_api_key()
|
| 91 |
if not api_key:
|
| 92 |
raise ValueError(
|
| 93 |
"GROQ_API_KEY not found in environment.\n"
|
| 94 |
"Get your FREE API key at: https://console.groq.com/keys"
|
| 95 |
)
|
| 96 |
-
|
| 97 |
# Use model from environment or default
|
| 98 |
model = model or get_groq_model()
|
| 99 |
-
|
| 100 |
return ChatGroq(
|
| 101 |
model=model,
|
| 102 |
temperature=temperature,
|
| 103 |
api_key=api_key,
|
| 104 |
model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {}
|
| 105 |
)
|
| 106 |
-
|
| 107 |
elif provider == "gemini":
|
| 108 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 109 |
-
|
| 110 |
api_key = get_google_api_key()
|
| 111 |
if not api_key:
|
| 112 |
raise ValueError(
|
| 113 |
"GOOGLE_API_KEY not found in environment.\n"
|
| 114 |
"Get your FREE API key at: https://aistudio.google.com/app/apikey"
|
| 115 |
)
|
| 116 |
-
|
| 117 |
# Use model from environment or default
|
| 118 |
model = model or get_gemini_model()
|
| 119 |
-
|
| 120 |
return ChatGoogleGenerativeAI(
|
| 121 |
model=model,
|
| 122 |
temperature=temperature,
|
| 123 |
google_api_key=api_key,
|
| 124 |
convert_system_message_to_human=True
|
| 125 |
)
|
| 126 |
-
|
| 127 |
elif provider == "ollama":
|
| 128 |
try:
|
| 129 |
from langchain_ollama import ChatOllama
|
| 130 |
except ImportError:
|
| 131 |
from langchain_community.chat_models import ChatOllama
|
| 132 |
-
|
| 133 |
model = model or "llama3.1:8b"
|
| 134 |
-
|
| 135 |
return ChatOllama(
|
| 136 |
model=model,
|
| 137 |
temperature=temperature,
|
| 138 |
format='json' if json_mode else None
|
| 139 |
)
|
| 140 |
-
|
| 141 |
else:
|
| 142 |
raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
|
| 143 |
|
|
@@ -147,7 +148,7 @@ def get_embedding_provider() -> str:
|
|
| 147 |
return _get_env_with_fallback("EMBEDDING_PROVIDER", "EMBEDDING__PROVIDER", "huggingface")
|
| 148 |
|
| 149 |
|
| 150 |
-
def get_embedding_model(provider:
|
| 151 |
"""
|
| 152 |
Get embedding model for vector search.
|
| 153 |
|
|
@@ -162,7 +163,7 @@ def get_embedding_model(provider: Optional[Literal["jina", "google", "huggingfac
|
|
| 162 |
which has automatic fallback chain: Jina → Google → HuggingFace.
|
| 163 |
"""
|
| 164 |
provider = provider or get_embedding_provider()
|
| 165 |
-
|
| 166 |
if provider == "jina":
|
| 167 |
# Try Jina AI embeddings first (high quality, 1024d)
|
| 168 |
jina_key = _get_env_with_fallback("JINA_API_KEY", "EMBEDDING__JINA_API_KEY", "")
|
|
@@ -178,15 +179,15 @@ def get_embedding_model(provider: Optional[Literal["jina", "google", "huggingfac
|
|
| 178 |
else:
|
| 179 |
print("WARN: JINA_API_KEY not found. Falling back to Google embeddings.")
|
| 180 |
return get_embedding_model("google")
|
| 181 |
-
|
| 182 |
elif provider == "google":
|
| 183 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 184 |
-
|
| 185 |
api_key = get_google_api_key()
|
| 186 |
if not api_key:
|
| 187 |
print("WARN: GOOGLE_API_KEY not found. Falling back to HuggingFace embeddings.")
|
| 188 |
return get_embedding_model("huggingface")
|
| 189 |
-
|
| 190 |
try:
|
| 191 |
return GoogleGenerativeAIEmbeddings(
|
| 192 |
model="models/text-embedding-004",
|
|
@@ -196,33 +197,33 @@ def get_embedding_model(provider: Optional[Literal["jina", "google", "huggingfac
|
|
| 196 |
print(f"WARN: Google embeddings failed: {e}")
|
| 197 |
print("INFO: Falling back to HuggingFace embeddings...")
|
| 198 |
return get_embedding_model("huggingface")
|
| 199 |
-
|
| 200 |
elif provider == "huggingface":
|
| 201 |
try:
|
| 202 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 203 |
except ImportError:
|
| 204 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 205 |
-
|
| 206 |
return HuggingFaceEmbeddings(
|
| 207 |
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
| 208 |
)
|
| 209 |
-
|
| 210 |
elif provider == "ollama":
|
| 211 |
try:
|
| 212 |
from langchain_ollama import OllamaEmbeddings
|
| 213 |
except ImportError:
|
| 214 |
from langchain_community.embeddings import OllamaEmbeddings
|
| 215 |
-
|
| 216 |
return OllamaEmbeddings(model="nomic-embed-text")
|
| 217 |
-
|
| 218 |
else:
|
| 219 |
raise ValueError(f"Unknown embedding provider: {provider}")
|
| 220 |
|
| 221 |
|
| 222 |
class LLMConfig:
|
| 223 |
"""Central configuration for all LLM models"""
|
| 224 |
-
|
| 225 |
-
def __init__(self, provider:
|
| 226 |
"""
|
| 227 |
Initialize all model clients.
|
| 228 |
|
|
@@ -236,7 +237,7 @@ class LLMConfig:
|
|
| 236 |
self._initialized = False
|
| 237 |
self._initialized_provider = None # Track which provider was initialized
|
| 238 |
self._lock = threading.Lock()
|
| 239 |
-
|
| 240 |
# Lazy-initialized model instances
|
| 241 |
self._planner = None
|
| 242 |
self._analyzer = None
|
|
@@ -245,15 +246,15 @@ class LLMConfig:
|
|
| 245 |
self._synthesizer_8b = None
|
| 246 |
self._director = None
|
| 247 |
self._embedding_model = None
|
| 248 |
-
|
| 249 |
if not lazy:
|
| 250 |
self._initialize_models()
|
| 251 |
-
|
| 252 |
@property
|
| 253 |
def provider(self) -> str:
|
| 254 |
"""Get current provider (dynamic lookup if not explicitly set)."""
|
| 255 |
return self._explicit_provider or get_default_llm_provider()
|
| 256 |
-
|
| 257 |
def _check_provider_change(self):
|
| 258 |
"""Check if provider changed and reinitialize if needed."""
|
| 259 |
current = self.provider
|
|
@@ -266,120 +267,120 @@ class LLMConfig:
|
|
| 266 |
self._synthesizer_7b = None
|
| 267 |
self._synthesizer_8b = None
|
| 268 |
self._director = None
|
| 269 |
-
|
| 270 |
def _initialize_models(self):
|
| 271 |
"""Initialize all model clients (called on first use if lazy)"""
|
| 272 |
self._check_provider_change()
|
| 273 |
-
|
| 274 |
if self._initialized:
|
| 275 |
return
|
| 276 |
-
|
| 277 |
with self._lock:
|
| 278 |
# Double-checked locking
|
| 279 |
if self._initialized:
|
| 280 |
return
|
| 281 |
-
|
| 282 |
print(f"Initializing LLM models with provider: {self.provider.upper()}")
|
| 283 |
-
|
| 284 |
# Fast model for structured tasks (planning, analysis)
|
| 285 |
self._planner = get_chat_model(
|
| 286 |
provider=self.provider,
|
| 287 |
temperature=0.0,
|
| 288 |
json_mode=True
|
| 289 |
)
|
| 290 |
-
|
| 291 |
# Fast model for biomarker analysis and quick tasks
|
| 292 |
self._analyzer = get_chat_model(
|
| 293 |
provider=self.provider,
|
| 294 |
temperature=0.0
|
| 295 |
)
|
| 296 |
-
|
| 297 |
# Medium model for RAG retrieval and explanation
|
| 298 |
self._explainer = get_chat_model(
|
| 299 |
provider=self.provider,
|
| 300 |
temperature=0.2
|
| 301 |
)
|
| 302 |
-
|
| 303 |
# Configurable synthesizers
|
| 304 |
self._synthesizer_7b = get_chat_model(
|
| 305 |
provider=self.provider,
|
| 306 |
temperature=0.2
|
| 307 |
)
|
| 308 |
-
|
| 309 |
self._synthesizer_8b = get_chat_model(
|
| 310 |
provider=self.provider,
|
| 311 |
temperature=0.2
|
| 312 |
)
|
| 313 |
-
|
| 314 |
# Director for Outer Loop
|
| 315 |
self._director = get_chat_model(
|
| 316 |
provider=self.provider,
|
| 317 |
temperature=0.0,
|
| 318 |
json_mode=True
|
| 319 |
)
|
| 320 |
-
|
| 321 |
-
# Embedding model for RAG
|
| 322 |
self._embedding_model = get_embedding_model()
|
| 323 |
-
|
| 324 |
self._initialized = True
|
| 325 |
self._initialized_provider = self.provider
|
| 326 |
-
|
| 327 |
@property
|
| 328 |
def planner(self):
|
| 329 |
self._initialize_models()
|
| 330 |
return self._planner
|
| 331 |
-
|
| 332 |
@property
|
| 333 |
def analyzer(self):
|
| 334 |
self._initialize_models()
|
| 335 |
return self._analyzer
|
| 336 |
-
|
| 337 |
@property
|
| 338 |
def explainer(self):
|
| 339 |
self._initialize_models()
|
| 340 |
return self._explainer
|
| 341 |
-
|
| 342 |
@property
|
| 343 |
def synthesizer_7b(self):
|
| 344 |
self._initialize_models()
|
| 345 |
return self._synthesizer_7b
|
| 346 |
-
|
| 347 |
@property
|
| 348 |
def synthesizer_8b(self):
|
| 349 |
self._initialize_models()
|
| 350 |
return self._synthesizer_8b
|
| 351 |
-
|
| 352 |
@property
|
| 353 |
def director(self):
|
| 354 |
self._initialize_models()
|
| 355 |
return self._director
|
| 356 |
-
|
| 357 |
@property
|
| 358 |
def embedding_model(self):
|
| 359 |
self._initialize_models()
|
| 360 |
return self._embedding_model
|
| 361 |
-
|
| 362 |
-
def get_synthesizer(self, model_name:
|
| 363 |
"""Get synthesizer model (for backward compatibility)"""
|
| 364 |
if model_name:
|
| 365 |
return get_chat_model(provider=self.provider, model=model_name, temperature=0.2)
|
| 366 |
return self.synthesizer_8b
|
| 367 |
-
|
| 368 |
def print_config(self):
|
| 369 |
"""Print current LLM configuration"""
|
| 370 |
print("=" * 60)
|
| 371 |
print("MediGuard AI RAG-Helper - LLM Configuration")
|
| 372 |
print("=" * 60)
|
| 373 |
print(f"Provider: {self.provider.upper()}")
|
| 374 |
-
|
| 375 |
if self.provider == "groq":
|
| 376 |
-
print(
|
| 377 |
elif self.provider == "gemini":
|
| 378 |
-
print(
|
| 379 |
else:
|
| 380 |
-
print(
|
| 381 |
-
|
| 382 |
-
print(
|
| 383 |
print("=" * 60)
|
| 384 |
|
| 385 |
|
|
@@ -387,7 +388,7 @@ class LLMConfig:
|
|
| 387 |
llm_config = LLMConfig()
|
| 388 |
|
| 389 |
|
| 390 |
-
def get_synthesizer(model_name:
|
| 391 |
"""Module-level convenience: get a synthesizer LLM instance."""
|
| 392 |
return llm_config.get_synthesizer(model_name)
|
| 393 |
|
|
@@ -395,7 +396,7 @@ def get_synthesizer(model_name: Optional[str] = None):
|
|
| 395 |
def check_api_connection():
|
| 396 |
"""Verify API connection and keys are configured"""
|
| 397 |
provider = DEFAULT_LLM_PROVIDER
|
| 398 |
-
|
| 399 |
try:
|
| 400 |
if provider == "groq":
|
| 401 |
api_key = os.getenv("GROQ_API_KEY")
|
|
@@ -404,13 +405,13 @@ def check_api_connection():
|
|
| 404 |
print("\n Get your FREE API key at:")
|
| 405 |
print(" https://console.groq.com/keys")
|
| 406 |
return False
|
| 407 |
-
|
| 408 |
# Test connection
|
| 409 |
test_model = get_chat_model("groq")
|
| 410 |
response = test_model.invoke("Say 'OK' in one word")
|
| 411 |
print("OK: Groq API connection successful")
|
| 412 |
return True
|
| 413 |
-
|
| 414 |
elif provider == "gemini":
|
| 415 |
api_key = os.getenv("GOOGLE_API_KEY")
|
| 416 |
if not api_key:
|
|
@@ -418,12 +419,12 @@ def check_api_connection():
|
|
| 418 |
print("\n Get your FREE API key at:")
|
| 419 |
print(" https://aistudio.google.com/app/apikey")
|
| 420 |
return False
|
| 421 |
-
|
| 422 |
test_model = get_chat_model("gemini")
|
| 423 |
response = test_model.invoke("Say 'OK' in one word")
|
| 424 |
print("OK: Google Gemini API connection successful")
|
| 425 |
return True
|
| 426 |
-
|
| 427 |
else:
|
| 428 |
try:
|
| 429 |
from langchain_ollama import ChatOllama
|
|
@@ -433,7 +434,7 @@ def check_api_connection():
|
|
| 433 |
response = test_model.invoke("Hello")
|
| 434 |
print("OK: Ollama connection successful")
|
| 435 |
return True
|
| 436 |
-
|
| 437 |
except Exception as e:
|
| 438 |
print(f"ERROR: Connection failed: {e}")
|
| 439 |
return False
|
|
|
|
| 14 |
|
| 15 |
import os
|
| 16 |
import threading
|
| 17 |
+
from typing import Literal
|
| 18 |
+
|
| 19 |
from dotenv import load_dotenv
|
| 20 |
|
| 21 |
# Load environment variables
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
def get_chat_model(
|
| 68 |
+
provider: Literal["groq", "gemini", "ollama"] | None = None,
|
| 69 |
+
model: str | None = None,
|
| 70 |
temperature: float = 0.0,
|
| 71 |
json_mode: bool = False
|
| 72 |
):
|
|
|
|
| 84 |
"""
|
| 85 |
# Use dynamic lookup to get current provider from environment
|
| 86 |
provider = provider or get_default_llm_provider()
|
| 87 |
+
|
| 88 |
if provider == "groq":
|
| 89 |
from langchain_groq import ChatGroq
|
| 90 |
+
|
| 91 |
api_key = get_groq_api_key()
|
| 92 |
if not api_key:
|
| 93 |
raise ValueError(
|
| 94 |
"GROQ_API_KEY not found in environment.\n"
|
| 95 |
"Get your FREE API key at: https://console.groq.com/keys"
|
| 96 |
)
|
| 97 |
+
|
| 98 |
# Use model from environment or default
|
| 99 |
model = model or get_groq_model()
|
| 100 |
+
|
| 101 |
return ChatGroq(
|
| 102 |
model=model,
|
| 103 |
temperature=temperature,
|
| 104 |
api_key=api_key,
|
| 105 |
model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {}
|
| 106 |
)
|
| 107 |
+
|
| 108 |
elif provider == "gemini":
|
| 109 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 110 |
+
|
| 111 |
api_key = get_google_api_key()
|
| 112 |
if not api_key:
|
| 113 |
raise ValueError(
|
| 114 |
"GOOGLE_API_KEY not found in environment.\n"
|
| 115 |
"Get your FREE API key at: https://aistudio.google.com/app/apikey"
|
| 116 |
)
|
| 117 |
+
|
| 118 |
# Use model from environment or default
|
| 119 |
model = model or get_gemini_model()
|
| 120 |
+
|
| 121 |
return ChatGoogleGenerativeAI(
|
| 122 |
model=model,
|
| 123 |
temperature=temperature,
|
| 124 |
google_api_key=api_key,
|
| 125 |
convert_system_message_to_human=True
|
| 126 |
)
|
| 127 |
+
|
| 128 |
elif provider == "ollama":
|
| 129 |
try:
|
| 130 |
from langchain_ollama import ChatOllama
|
| 131 |
except ImportError:
|
| 132 |
from langchain_community.chat_models import ChatOllama
|
| 133 |
+
|
| 134 |
model = model or "llama3.1:8b"
|
| 135 |
+
|
| 136 |
return ChatOllama(
|
| 137 |
model=model,
|
| 138 |
temperature=temperature,
|
| 139 |
format='json' if json_mode else None
|
| 140 |
)
|
| 141 |
+
|
| 142 |
else:
|
| 143 |
raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
|
| 144 |
|
|
|
|
| 148 |
return _get_env_with_fallback("EMBEDDING_PROVIDER", "EMBEDDING__PROVIDER", "huggingface")
|
| 149 |
|
| 150 |
|
| 151 |
+
def get_embedding_model(provider: Literal["jina", "google", "huggingface", "ollama"] | None = None):
|
| 152 |
"""
|
| 153 |
Get embedding model for vector search.
|
| 154 |
|
|
|
|
| 163 |
which has automatic fallback chain: Jina → Google → HuggingFace.
|
| 164 |
"""
|
| 165 |
provider = provider or get_embedding_provider()
|
| 166 |
+
|
| 167 |
if provider == "jina":
|
| 168 |
# Try Jina AI embeddings first (high quality, 1024d)
|
| 169 |
jina_key = _get_env_with_fallback("JINA_API_KEY", "EMBEDDING__JINA_API_KEY", "")
|
|
|
|
| 179 |
else:
|
| 180 |
print("WARN: JINA_API_KEY not found. Falling back to Google embeddings.")
|
| 181 |
return get_embedding_model("google")
|
| 182 |
+
|
| 183 |
elif provider == "google":
|
| 184 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 185 |
+
|
| 186 |
api_key = get_google_api_key()
|
| 187 |
if not api_key:
|
| 188 |
print("WARN: GOOGLE_API_KEY not found. Falling back to HuggingFace embeddings.")
|
| 189 |
return get_embedding_model("huggingface")
|
| 190 |
+
|
| 191 |
try:
|
| 192 |
return GoogleGenerativeAIEmbeddings(
|
| 193 |
model="models/text-embedding-004",
|
|
|
|
| 197 |
print(f"WARN: Google embeddings failed: {e}")
|
| 198 |
print("INFO: Falling back to HuggingFace embeddings...")
|
| 199 |
return get_embedding_model("huggingface")
|
| 200 |
+
|
| 201 |
elif provider == "huggingface":
|
| 202 |
try:
|
| 203 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 204 |
except ImportError:
|
| 205 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 206 |
+
|
| 207 |
return HuggingFaceEmbeddings(
|
| 208 |
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
| 209 |
)
|
| 210 |
+
|
| 211 |
elif provider == "ollama":
|
| 212 |
try:
|
| 213 |
from langchain_ollama import OllamaEmbeddings
|
| 214 |
except ImportError:
|
| 215 |
from langchain_community.embeddings import OllamaEmbeddings
|
| 216 |
+
|
| 217 |
return OllamaEmbeddings(model="nomic-embed-text")
|
| 218 |
+
|
| 219 |
else:
|
| 220 |
raise ValueError(f"Unknown embedding provider: {provider}")
|
| 221 |
|
| 222 |
|
| 223 |
class LLMConfig:
|
| 224 |
"""Central configuration for all LLM models"""
|
| 225 |
+
|
| 226 |
+
def __init__(self, provider: str | None = None, lazy: bool = True):
|
| 227 |
"""
|
| 228 |
Initialize all model clients.
|
| 229 |
|
|
|
|
| 237 |
self._initialized = False
|
| 238 |
self._initialized_provider = None # Track which provider was initialized
|
| 239 |
self._lock = threading.Lock()
|
| 240 |
+
|
| 241 |
# Lazy-initialized model instances
|
| 242 |
self._planner = None
|
| 243 |
self._analyzer = None
|
|
|
|
| 246 |
self._synthesizer_8b = None
|
| 247 |
self._director = None
|
| 248 |
self._embedding_model = None
|
| 249 |
+
|
| 250 |
if not lazy:
|
| 251 |
self._initialize_models()
|
| 252 |
+
|
| 253 |
@property
def provider(self) -> str:
    """Resolve the active LLM provider.

    An explicitly supplied provider (constructor argument) always wins;
    otherwise the value is looked up dynamically from the environment so
    runtime changes to the env var are picked up.
    """
    explicit = self._explicit_provider
    if explicit:
        return explicit
    return get_default_llm_provider()
|
| 257 |
+
|
| 258 |
def _check_provider_change(self):
|
| 259 |
"""Check if provider changed and reinitialize if needed."""
|
| 260 |
current = self.provider
|
|
|
|
| 267 |
self._synthesizer_7b = None
|
| 268 |
self._synthesizer_8b = None
|
| 269 |
self._director = None
|
| 270 |
+
|
| 271 |
def _initialize_models(self):
    """Initialize all model clients (called on first use if lazy).

    Thread-safe via double-checked locking: the cheap ``_initialized``
    flag is tested before and after acquiring ``self._lock`` so that
    concurrent first accesses only build the models once.  If the
    configured provider changed since the last init,
    ``_check_provider_change`` resets the cached instances first.
    """
    self._check_provider_change()

    # Fast path: already initialized, no lock needed.
    if self._initialized:
        return

    with self._lock:
        # Double-checked locking
        if self._initialized:
            return

        print(f"Initializing LLM models with provider: {self.provider.upper()}")

        # Fast model for structured tasks (planning, analysis)
        # json_mode=True: planner output must be machine-parseable.
        self._planner = get_chat_model(
            provider=self.provider,
            temperature=0.0,
            json_mode=True
        )

        # Fast model for biomarker analysis and quick tasks
        # temperature 0.0 keeps analysis deterministic.
        self._analyzer = get_chat_model(
            provider=self.provider,
            temperature=0.0
        )

        # Medium model for RAG retrieval and explanation
        # slight temperature (0.2) allows more natural prose.
        self._explainer = get_chat_model(
            provider=self.provider,
            temperature=0.2
        )

        # Configurable synthesizers
        # NOTE(review): both synthesizers are built with identical settings
        # here — presumably the 7b/8b split is a historical distinction.
        self._synthesizer_7b = get_chat_model(
            provider=self.provider,
            temperature=0.2
        )

        self._synthesizer_8b = get_chat_model(
            provider=self.provider,
            temperature=0.2
        )

        # Director for Outer Loop
        # json_mode=True: director decisions are consumed programmatically.
        self._director = get_chat_model(
            provider=self.provider,
            temperature=0.0,
            json_mode=True
        )

        # Embedding model for RAG
        self._embedding_model = get_embedding_model()

        # Record success and which provider the instances belong to, so a
        # later provider switch can be detected and trigger re-init.
        self._initialized = True
        self._initialized_provider = self.provider
|
| 327 |
+
|
| 328 |
@property
def planner(self):
    """Structured-output planning model, built lazily on first access."""
    self._initialize_models()
    model = self._planner
    return model
|
| 332 |
+
|
| 333 |
@property
def analyzer(self):
    """Deterministic analysis model, built lazily on first access."""
    self._initialize_models()
    model = self._analyzer
    return model
|
| 337 |
+
|
| 338 |
@property
def explainer(self):
    """RAG explanation model, built lazily on first access."""
    self._initialize_models()
    model = self._explainer
    return model
|
| 342 |
+
|
| 343 |
@property
def synthesizer_7b(self):
    """Synthesizer model (7b slot), built lazily on first access."""
    self._initialize_models()
    model = self._synthesizer_7b
    return model
|
| 347 |
+
|
| 348 |
@property
def synthesizer_8b(self):
    """Synthesizer model (8b slot, the default), built lazily on first access."""
    self._initialize_models()
    model = self._synthesizer_8b
    return model
|
| 352 |
+
|
| 353 |
@property
def director(self):
    """Outer-loop director model (JSON output), built lazily on first access."""
    self._initialize_models()
    model = self._director
    return model
|
| 357 |
+
|
| 358 |
@property
def embedding_model(self):
    """Embedding model for RAG vector search, built lazily on first access."""
    self._initialize_models()
    model = self._embedding_model
    return model
|
| 362 |
+
|
| 363 |
+
def get_synthesizer(self, model_name: str | None = None):
    """Return a synthesizer LLM (kept for backward compatibility).

    When *model_name* is provided, a fresh chat model is built for that
    name; otherwise the cached default synthesizer is reused.
    """
    if not model_name:
        return self.synthesizer_8b
    return get_chat_model(provider=self.provider, model=model_name, temperature=0.2)
|
| 368 |
+
|
| 369 |
def print_config(self):
    """Display the active LLM configuration on stdout."""
    separator = "=" * 60
    active = self.provider
    # Provider -> human-readable model label; fallback covers ollama/local.
    model_labels = {
        "groq": "Model: llama-3.3-70b-versatile (FREE)",
        "gemini": "Model: gemini-2.0-flash (FREE)",
    }
    print(separator)
    print("MediGuard AI RAG-Helper - LLM Configuration")
    print(separator)
    print(f"Provider: {active.upper()}")
    print(model_labels.get(active, "Model: llama3.1:8b (local)"))
    print("Embeddings: Google Gemini (FREE)")
    print(separator)
|
| 385 |
|
| 386 |
|
|
|
|
| 388 |
llm_config = LLMConfig()
|
| 389 |
|
| 390 |
|
| 391 |
+
def get_synthesizer(model_name: str | None = None):
    """Convenience shortcut that delegates to the shared ``llm_config`` singleton."""
    synthesizer = llm_config.get_synthesizer(model_name)
    return synthesizer
|
| 394 |
|
|
|
|
| 396 |
def check_api_connection():
|
| 397 |
"""Verify API connection and keys are configured"""
|
| 398 |
provider = DEFAULT_LLM_PROVIDER
|
| 399 |
+
|
| 400 |
try:
|
| 401 |
if provider == "groq":
|
| 402 |
api_key = os.getenv("GROQ_API_KEY")
|
|
|
|
| 405 |
print("\n Get your FREE API key at:")
|
| 406 |
print(" https://console.groq.com/keys")
|
| 407 |
return False
|
| 408 |
+
|
| 409 |
# Test connection
|
| 410 |
test_model = get_chat_model("groq")
|
| 411 |
response = test_model.invoke("Say 'OK' in one word")
|
| 412 |
print("OK: Groq API connection successful")
|
| 413 |
return True
|
| 414 |
+
|
| 415 |
elif provider == "gemini":
|
| 416 |
api_key = os.getenv("GOOGLE_API_KEY")
|
| 417 |
if not api_key:
|
|
|
|
| 419 |
print("\n Get your FREE API key at:")
|
| 420 |
print(" https://aistudio.google.com/app/apikey")
|
| 421 |
return False
|
| 422 |
+
|
| 423 |
test_model = get_chat_model("gemini")
|
| 424 |
response = test_model.invoke("Say 'OK' in one word")
|
| 425 |
print("OK: Google Gemini API connection successful")
|
| 426 |
return True
|
| 427 |
+
|
| 428 |
else:
|
| 429 |
try:
|
| 430 |
from langchain_ollama import ChatOllama
|
|
|
|
| 434 |
response = test_model.invoke("Hello")
|
| 435 |
print("OK: Ollama connection successful")
|
| 436 |
return True
|
| 437 |
+
|
| 438 |
except Exception as e:
|
| 439 |
print(f"ERROR: Connection failed: {e}")
|
| 440 |
return False
|
src/main.py
CHANGED
|
@@ -13,7 +13,7 @@ import logging
|
|
| 13 |
import os
|
| 14 |
import time
|
| 15 |
from contextlib import asynccontextmanager
|
| 16 |
-
from datetime import
|
| 17 |
|
| 18 |
from fastapi import FastAPI, Request, status
|
| 19 |
from fastapi.exceptions import RequestValidationError
|
|
@@ -49,7 +49,9 @@ async def lifespan(app: FastAPI):
|
|
| 49 |
# --- OpenSearch ---
|
| 50 |
try:
|
| 51 |
from src.services.opensearch.client import make_opensearch_client
|
|
|
|
| 52 |
app.state.opensearch_client = make_opensearch_client()
|
|
|
|
| 53 |
logger.info("OpenSearch client ready")
|
| 54 |
except Exception as exc:
|
| 55 |
logger.warning("OpenSearch unavailable: %s", exc)
|
|
@@ -59,7 +61,7 @@ async def lifespan(app: FastAPI):
|
|
| 59 |
try:
|
| 60 |
from src.services.embeddings.service import make_embedding_service
|
| 61 |
app.state.embedding_service = make_embedding_service()
|
| 62 |
-
logger.info("Embedding service ready (provider=%s)", app.state.embedding_service.
|
| 63 |
except Exception as exc:
|
| 64 |
logger.warning("Embedding service unavailable: %s", exc)
|
| 65 |
app.state.embedding_service = None
|
|
@@ -93,11 +95,11 @@ async def lifespan(app: FastAPI):
|
|
| 93 |
|
| 94 |
# --- Agentic RAG service ---
|
| 95 |
try:
|
|
|
|
| 96 |
from src.services.agents.agentic_rag import AgenticRAGService
|
| 97 |
from src.services.agents.context import AgenticContext
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
llm = app.state.ollama_client.get_langchain_model()
|
| 101 |
ctx = AgenticContext(
|
| 102 |
llm=llm,
|
| 103 |
embedding_service=app.state.embedding_service,
|
|
@@ -109,17 +111,16 @@ async def lifespan(app: FastAPI):
|
|
| 109 |
logger.info("Agentic RAG service ready")
|
| 110 |
else:
|
| 111 |
app.state.rag_service = None
|
| 112 |
-
logger.warning("Agentic RAG service skipped — missing backing services")
|
| 113 |
except Exception as exc:
|
| 114 |
logger.warning("Agentic RAG service failed: %s", exc)
|
| 115 |
app.state.rag_service = None
|
| 116 |
|
| 117 |
# --- Legacy RagBot service (backward-compatible /analyze) ---
|
| 118 |
try:
|
| 119 |
-
from
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
app.state.ragbot_service = ragbot
|
| 123 |
logger.info("RagBot service ready (ClinicalInsightGuild)")
|
| 124 |
except Exception as exc:
|
| 125 |
logger.warning("RagBot service unavailable: %s", exc)
|
|
@@ -127,17 +128,13 @@ async def lifespan(app: FastAPI):
|
|
| 127 |
|
| 128 |
# --- Extraction service (for natural language input) ---
|
| 129 |
try:
|
|
|
|
| 130 |
from src.services.extraction.service import make_extraction_service
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
llm = getattr(app.state.rag_service, '_context', {})
|
| 137 |
-
if hasattr(llm, 'llm'):
|
| 138 |
-
llm = llm.llm
|
| 139 |
-
else:
|
| 140 |
-
llm = None
|
| 141 |
# If no LLM available, extraction will use regex fallback
|
| 142 |
app.state.extraction_service = make_extraction_service(llm=llm)
|
| 143 |
logger.info("Extraction service ready")
|
|
@@ -196,7 +193,7 @@ def create_app() -> FastAPI:
|
|
| 196 |
"error_code": "VALIDATION_ERROR",
|
| 197 |
"message": "Request validation failed",
|
| 198 |
"details": exc.errors(),
|
| 199 |
-
"timestamp": datetime.now(
|
| 200 |
},
|
| 201 |
)
|
| 202 |
|
|
@@ -209,12 +206,12 @@ def create_app() -> FastAPI:
|
|
| 209 |
"status": "error",
|
| 210 |
"error_code": "INTERNAL_SERVER_ERROR",
|
| 211 |
"message": "An unexpected error occurred. Please try again later.",
|
| 212 |
-
"timestamp": datetime.now(
|
| 213 |
},
|
| 214 |
)
|
| 215 |
|
| 216 |
# --- Routers ---
|
| 217 |
-
from src.routers import
|
| 218 |
|
| 219 |
app.include_router(health.router)
|
| 220 |
app.include_router(analyze.router)
|
|
|
|
| 13 |
import os
|
| 14 |
import time
|
| 15 |
from contextlib import asynccontextmanager
|
| 16 |
+
from datetime import UTC, datetime
|
| 17 |
|
| 18 |
from fastapi import FastAPI, Request, status
|
| 19 |
from fastapi.exceptions import RequestValidationError
|
|
|
|
| 49 |
# --- OpenSearch ---
|
| 50 |
try:
|
| 51 |
from src.services.opensearch.client import make_opensearch_client
|
| 52 |
+
from src.services.opensearch.index_config import MEDICAL_CHUNKS_MAPPING
|
| 53 |
app.state.opensearch_client = make_opensearch_client()
|
| 54 |
+
app.state.opensearch_client.ensure_index(MEDICAL_CHUNKS_MAPPING)
|
| 55 |
logger.info("OpenSearch client ready")
|
| 56 |
except Exception as exc:
|
| 57 |
logger.warning("OpenSearch unavailable: %s", exc)
|
|
|
|
| 61 |
try:
|
| 62 |
from src.services.embeddings.service import make_embedding_service
|
| 63 |
app.state.embedding_service = make_embedding_service()
|
| 64 |
+
logger.info("Embedding service ready (provider=%s)", app.state.embedding_service.provider_name)
|
| 65 |
except Exception as exc:
|
| 66 |
logger.warning("Embedding service unavailable: %s", exc)
|
| 67 |
app.state.embedding_service = None
|
|
|
|
| 95 |
|
| 96 |
# --- Agentic RAG service ---
|
| 97 |
try:
|
| 98 |
+
from src.llm_config import get_llm
|
| 99 |
from src.services.agents.agentic_rag import AgenticRAGService
|
| 100 |
from src.services.agents.context import AgenticContext
|
| 101 |
+
if app.state.opensearch_client and app.state.embedding_service:
|
| 102 |
+
llm = get_llm()
|
|
|
|
| 103 |
ctx = AgenticContext(
|
| 104 |
llm=llm,
|
| 105 |
embedding_service=app.state.embedding_service,
|
|
|
|
| 111 |
logger.info("Agentic RAG service ready")
|
| 112 |
else:
|
| 113 |
app.state.rag_service = None
|
| 114 |
+
logger.warning("Agentic RAG service skipped — missing backing services (OpenSearch or Embedding)")
|
| 115 |
except Exception as exc:
|
| 116 |
logger.warning("Agentic RAG service failed: %s", exc)
|
| 117 |
app.state.rag_service = None
|
| 118 |
|
| 119 |
# --- Legacy RagBot service (backward-compatible /analyze) ---
|
| 120 |
try:
|
| 121 |
+
from src.workflow import create_guild
|
| 122 |
+
guild = create_guild()
|
| 123 |
+
app.state.ragbot_service = guild
|
|
|
|
| 124 |
logger.info("RagBot service ready (ClinicalInsightGuild)")
|
| 125 |
except Exception as exc:
|
| 126 |
logger.warning("RagBot service unavailable: %s", exc)
|
|
|
|
| 128 |
|
| 129 |
# --- Extraction service (for natural language input) ---
|
| 130 |
try:
|
| 131 |
+
from src.llm_config import get_llm
|
| 132 |
from src.services.extraction.service import make_extraction_service
|
| 133 |
+
try:
|
| 134 |
+
llm = get_llm()
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.warning("Failed to get LLM for extraction, will use fallback: %s", e)
|
| 137 |
+
llm = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
# If no LLM available, extraction will use regex fallback
|
| 139 |
app.state.extraction_service = make_extraction_service(llm=llm)
|
| 140 |
logger.info("Extraction service ready")
|
|
|
|
| 193 |
"error_code": "VALIDATION_ERROR",
|
| 194 |
"message": "Request validation failed",
|
| 195 |
"details": exc.errors(),
|
| 196 |
+
"timestamp": datetime.now(UTC).isoformat(),
|
| 197 |
},
|
| 198 |
)
|
| 199 |
|
|
|
|
| 206 |
"status": "error",
|
| 207 |
"error_code": "INTERNAL_SERVER_ERROR",
|
| 208 |
"message": "An unexpected error occurred. Please try again later.",
|
| 209 |
+
"timestamp": datetime.now(UTC).isoformat(),
|
| 210 |
},
|
| 211 |
)
|
| 212 |
|
| 213 |
# --- Routers ---
|
| 214 |
+
from src.routers import analyze, ask, health, search
|
| 215 |
|
| 216 |
app.include_router(health.router)
|
| 217 |
app.include_router(analyze.router)
|
src/middlewares.py
CHANGED
|
@@ -12,8 +12,9 @@ import json
|
|
| 12 |
import logging
|
| 13 |
import time
|
| 14 |
import uuid
|
| 15 |
-
from
|
| 16 |
-
from
|
|
|
|
| 17 |
|
| 18 |
from fastapi import Request, Response
|
| 19 |
from starlette.middleware.base import BaseHTTPMiddleware
|
|
@@ -74,35 +75,35 @@ class HIPAAAuditMiddleware(BaseHTTPMiddleware):
|
|
| 74 |
|
| 75 |
Audit logs are structured JSON for easy SIEM integration.
|
| 76 |
"""
|
| 77 |
-
|
| 78 |
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 79 |
# Generate request ID
|
| 80 |
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 81 |
request.state.request_id = request_id
|
| 82 |
-
|
| 83 |
# Start timing
|
| 84 |
start_time = time.time()
|
| 85 |
-
|
| 86 |
# Extract metadata safely
|
| 87 |
path = request.url.path
|
| 88 |
method = request.method
|
| 89 |
client_ip = request.client.host if request.client else "unknown"
|
| 90 |
user_agent = request.headers.get("user-agent", "unknown")[:100]
|
| 91 |
-
|
| 92 |
# Check if this endpoint needs audit logging
|
| 93 |
needs_audit = any(path.startswith(ep) for ep in AUDITABLE_ENDPOINTS)
|
| 94 |
-
|
| 95 |
# Pre-request audit entry
|
| 96 |
audit_entry: dict[str, Any] = {
|
| 97 |
"event": "request_start",
|
| 98 |
-
"timestamp": datetime.now(
|
| 99 |
"request_id": request_id,
|
| 100 |
"method": method,
|
| 101 |
"path": path,
|
| 102 |
"client_ip_hash": _hash_sensitive(client_ip),
|
| 103 |
"user_agent_hash": _hash_sensitive(user_agent),
|
| 104 |
}
|
| 105 |
-
|
| 106 |
# Try to read request body for POST requests (without logging PHI)
|
| 107 |
if needs_audit and method == "POST":
|
| 108 |
try:
|
|
@@ -116,35 +117,35 @@ class HIPAAAuditMiddleware(BaseHTTPMiddleware):
|
|
| 116 |
# Log presence of biomarkers without values
|
| 117 |
if "biomarkers" in body_dict:
|
| 118 |
audit_entry["biomarker_count"] = len(body_dict["biomarkers"]) if isinstance(body_dict["biomarkers"], dict) else 1
|
| 119 |
-
except Exception:
|
| 120 |
-
|
| 121 |
-
|
| 122 |
if needs_audit:
|
| 123 |
logger.info("AUDIT_REQUEST: %s", json.dumps(audit_entry))
|
| 124 |
-
|
| 125 |
# Process request
|
| 126 |
response: Response = await call_next(request)
|
| 127 |
-
|
| 128 |
# Post-request audit
|
| 129 |
elapsed_ms = (time.time() - start_time) * 1000
|
| 130 |
-
|
| 131 |
completion_entry = {
|
| 132 |
"event": "request_complete",
|
| 133 |
-
"timestamp": datetime.now(
|
| 134 |
"request_id": request_id,
|
| 135 |
"method": method,
|
| 136 |
"path": path,
|
| 137 |
"status_code": response.status_code,
|
| 138 |
"elapsed_ms": round(elapsed_ms, 2),
|
| 139 |
}
|
| 140 |
-
|
| 141 |
if needs_audit:
|
| 142 |
logger.info("AUDIT_COMPLETE: %s", json.dumps(completion_entry))
|
| 143 |
-
|
| 144 |
# Add request ID to response headers
|
| 145 |
response.headers["X-Request-ID"] = request_id
|
| 146 |
response.headers["X-Response-Time"] = f"{elapsed_ms:.2f}ms"
|
| 147 |
-
|
| 148 |
return response
|
| 149 |
|
| 150 |
|
|
@@ -152,10 +153,10 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|
| 152 |
"""
|
| 153 |
Add security headers for HIPAA compliance.
|
| 154 |
"""
|
| 155 |
-
|
| 156 |
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 157 |
response: Response = await call_next(request)
|
| 158 |
-
|
| 159 |
# Security headers
|
| 160 |
response.headers["X-Content-Type-Options"] = "nosniff"
|
| 161 |
response.headers["X-Frame-Options"] = "DENY"
|
|
@@ -163,9 +164,9 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|
| 163 |
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
| 164 |
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
| 165 |
response.headers["Pragma"] = "no-cache"
|
| 166 |
-
|
| 167 |
# Medical data should never be cached
|
| 168 |
if any(ep in request.url.path for ep in AUDITABLE_ENDPOINTS):
|
| 169 |
response.headers["Cache-Control"] = "no-store, private"
|
| 170 |
-
|
| 171 |
return response
|
|
|
|
| 12 |
import logging
|
| 13 |
import time
|
| 14 |
import uuid
|
| 15 |
+
from collections.abc import Callable
|
| 16 |
+
from datetime import UTC, datetime
|
| 17 |
+
from typing import Any
|
| 18 |
|
| 19 |
from fastapi import Request, Response
|
| 20 |
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
| 75 |
|
| 76 |
Audit logs are structured JSON for easy SIEM integration.
|
| 77 |
"""
|
| 78 |
+
|
| 79 |
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 80 |
# Generate request ID
|
| 81 |
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 82 |
request.state.request_id = request_id
|
| 83 |
+
|
| 84 |
# Start timing
|
| 85 |
start_time = time.time()
|
| 86 |
+
|
| 87 |
# Extract metadata safely
|
| 88 |
path = request.url.path
|
| 89 |
method = request.method
|
| 90 |
client_ip = request.client.host if request.client else "unknown"
|
| 91 |
user_agent = request.headers.get("user-agent", "unknown")[:100]
|
| 92 |
+
|
| 93 |
# Check if this endpoint needs audit logging
|
| 94 |
needs_audit = any(path.startswith(ep) for ep in AUDITABLE_ENDPOINTS)
|
| 95 |
+
|
| 96 |
# Pre-request audit entry
|
| 97 |
audit_entry: dict[str, Any] = {
|
| 98 |
"event": "request_start",
|
| 99 |
+
"timestamp": datetime.now(UTC).isoformat(),
|
| 100 |
"request_id": request_id,
|
| 101 |
"method": method,
|
| 102 |
"path": path,
|
| 103 |
"client_ip_hash": _hash_sensitive(client_ip),
|
| 104 |
"user_agent_hash": _hash_sensitive(user_agent),
|
| 105 |
}
|
| 106 |
+
|
| 107 |
# Try to read request body for POST requests (without logging PHI)
|
| 108 |
if needs_audit and method == "POST":
|
| 109 |
try:
|
|
|
|
| 117 |
# Log presence of biomarkers without values
|
| 118 |
if "biomarkers" in body_dict:
|
| 119 |
audit_entry["biomarker_count"] = len(body_dict["biomarkers"]) if isinstance(body_dict["biomarkers"], dict) else 1
|
| 120 |
+
except Exception as exc:
|
| 121 |
+
logger.debug("Failed to audit POST body: %s", exc)
|
| 122 |
+
|
| 123 |
if needs_audit:
|
| 124 |
logger.info("AUDIT_REQUEST: %s", json.dumps(audit_entry))
|
| 125 |
+
|
| 126 |
# Process request
|
| 127 |
response: Response = await call_next(request)
|
| 128 |
+
|
| 129 |
# Post-request audit
|
| 130 |
elapsed_ms = (time.time() - start_time) * 1000
|
| 131 |
+
|
| 132 |
completion_entry = {
|
| 133 |
"event": "request_complete",
|
| 134 |
+
"timestamp": datetime.now(UTC).isoformat(),
|
| 135 |
"request_id": request_id,
|
| 136 |
"method": method,
|
| 137 |
"path": path,
|
| 138 |
"status_code": response.status_code,
|
| 139 |
"elapsed_ms": round(elapsed_ms, 2),
|
| 140 |
}
|
| 141 |
+
|
| 142 |
if needs_audit:
|
| 143 |
logger.info("AUDIT_COMPLETE: %s", json.dumps(completion_entry))
|
| 144 |
+
|
| 145 |
# Add request ID to response headers
|
| 146 |
response.headers["X-Request-ID"] = request_id
|
| 147 |
response.headers["X-Response-Time"] = f"{elapsed_ms:.2f}ms"
|
| 148 |
+
|
| 149 |
return response
|
| 150 |
|
| 151 |
|
|
|
|
| 153 |
"""
|
| 154 |
Add security headers for HIPAA compliance.
|
| 155 |
"""
|
| 156 |
+
|
| 157 |
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 158 |
response: Response = await call_next(request)
|
| 159 |
+
|
| 160 |
# Security headers
|
| 161 |
response.headers["X-Content-Type-Options"] = "nosniff"
|
| 162 |
response.headers["X-Frame-Options"] = "DENY"
|
|
|
|
| 164 |
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
| 165 |
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
| 166 |
response.headers["Pragma"] = "no-cache"
|
| 167 |
+
|
| 168 |
# Medical data should never be cached
|
| 169 |
if any(ep in request.url.path for ep in AUDITABLE_ENDPOINTS):
|
| 170 |
response.headers["Cache-Control"] = "no-store, private"
|
| 171 |
+
|
| 172 |
return response
|
src/pdf_processor.py
CHANGED
|
@@ -6,13 +6,12 @@ PDF document processing and vector store creation
|
|
| 6 |
import os
|
| 7 |
import warnings
|
| 8 |
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
from langchain_community.vectorstores import FAISS
|
| 13 |
from langchain_core.documents import Document
|
| 14 |
-
from
|
| 15 |
-
import time
|
| 16 |
|
| 17 |
# Suppress noisy warnings
|
| 18 |
warnings.filterwarnings("ignore", message=".*class.*HuggingFaceEmbeddings.*was deprecated.*")
|
|
@@ -22,12 +21,12 @@ os.environ.setdefault("HF_HUB_DISABLE_IMPLICIT_TOKEN", "1")
|
|
| 22 |
load_dotenv()
|
| 23 |
|
| 24 |
# Re-export for backward compatibility
|
| 25 |
-
from src.llm_config import get_embedding_model
|
| 26 |
|
| 27 |
|
| 28 |
class PDFProcessor:
|
| 29 |
"""Handles medical PDF ingestion and vector store creation"""
|
| 30 |
-
|
| 31 |
def __init__(
|
| 32 |
self,
|
| 33 |
pdf_directory: str = "data/medical_pdfs",
|
|
@@ -48,11 +47,11 @@ class PDFProcessor:
|
|
| 48 |
self.vector_store_path = Path(vector_store_path)
|
| 49 |
self.chunk_size = chunk_size
|
| 50 |
self.chunk_overlap = chunk_overlap
|
| 51 |
-
|
| 52 |
# Create directories if they don't exist
|
| 53 |
self.pdf_directory.mkdir(parents=True, exist_ok=True)
|
| 54 |
self.vector_store_path.mkdir(parents=True, exist_ok=True)
|
| 55 |
-
|
| 56 |
# Text splitter with medical context awareness
|
| 57 |
self.text_splitter = RecursiveCharacterTextSplitter(
|
| 58 |
chunk_size=chunk_size,
|
|
@@ -60,8 +59,8 @@ class PDFProcessor:
|
|
| 60 |
separators=["\n\n", "\n", ". ", " ", ""],
|
| 61 |
length_function=len
|
| 62 |
)
|
| 63 |
-
|
| 64 |
-
def load_pdfs(self) ->
|
| 65 |
"""
|
| 66 |
Load all PDF documents from the configured directory.
|
| 67 |
|
|
@@ -69,40 +68,40 @@ class PDFProcessor:
|
|
| 69 |
List of Document objects with content and metadata
|
| 70 |
"""
|
| 71 |
print(f"Loading PDFs from: {self.pdf_directory}")
|
| 72 |
-
|
| 73 |
pdf_files = list(self.pdf_directory.glob("*.pdf"))
|
| 74 |
-
|
| 75 |
if not pdf_files:
|
| 76 |
print(f"WARN: No PDF files found in {self.pdf_directory}")
|
| 77 |
print("INFO: Please place medical PDFs in this directory")
|
| 78 |
return []
|
| 79 |
-
|
| 80 |
print(f"Found {len(pdf_files)} PDF file(s):")
|
| 81 |
for pdf in pdf_files:
|
| 82 |
print(f" - {pdf.name}")
|
| 83 |
-
|
| 84 |
documents = []
|
| 85 |
-
|
| 86 |
for pdf_path in pdf_files:
|
| 87 |
try:
|
| 88 |
loader = PyPDFLoader(str(pdf_path))
|
| 89 |
docs = loader.load()
|
| 90 |
-
|
| 91 |
# Add source filename to metadata
|
| 92 |
for doc in docs:
|
| 93 |
doc.metadata['source_file'] = pdf_path.name
|
| 94 |
doc.metadata['source_path'] = str(pdf_path)
|
| 95 |
-
|
| 96 |
documents.extend(docs)
|
| 97 |
print(f" OK: Loaded {len(docs)} pages from {pdf_path.name}")
|
| 98 |
-
|
| 99 |
except Exception as e:
|
| 100 |
print(f" ERROR: Error loading {pdf_path.name}: {e}")
|
| 101 |
-
|
| 102 |
print(f"\nTotal: {len(documents)} pages loaded from {len(pdf_files)} PDF(s)")
|
| 103 |
return documents
|
| 104 |
-
|
| 105 |
-
def chunk_documents(self, documents:
|
| 106 |
"""
|
| 107 |
Split documents into chunks for RAG retrieval.
|
| 108 |
|
|
@@ -113,25 +112,25 @@ class PDFProcessor:
|
|
| 113 |
List of chunked documents with preserved metadata
|
| 114 |
"""
|
| 115 |
print(f"\nChunking documents (size={self.chunk_size}, overlap={self.chunk_overlap})...")
|
| 116 |
-
|
| 117 |
chunks = self.text_splitter.split_documents(documents)
|
| 118 |
-
|
| 119 |
if not chunks:
|
| 120 |
print("WARN: No chunks generated from documents")
|
| 121 |
return chunks
|
| 122 |
-
|
| 123 |
# Add chunk index to metadata
|
| 124 |
for i, chunk in enumerate(chunks):
|
| 125 |
chunk.metadata['chunk_id'] = i
|
| 126 |
-
|
| 127 |
print(f"OK: Created {len(chunks)} chunks from {len(documents)} pages")
|
| 128 |
print(f" Average chunk size: {sum(len(c.page_content) for c in chunks) // len(chunks)} characters")
|
| 129 |
-
|
| 130 |
return chunks
|
| 131 |
-
|
| 132 |
def create_vector_store(
|
| 133 |
self,
|
| 134 |
-
chunks:
|
| 135 |
embedding_model,
|
| 136 |
store_name: str = "medical_knowledge"
|
| 137 |
) -> FAISS:
|
|
@@ -149,26 +148,26 @@ class PDFProcessor:
|
|
| 149 |
print(f"\nCreating vector store: {store_name}")
|
| 150 |
print(f"Generating embeddings for {len(chunks)} chunks...")
|
| 151 |
print("(This may take a few minutes...)")
|
| 152 |
-
|
| 153 |
# Create FAISS vector store
|
| 154 |
vector_store = FAISS.from_documents(
|
| 155 |
documents=chunks,
|
| 156 |
embedding=embedding_model
|
| 157 |
)
|
| 158 |
-
|
| 159 |
# Save to disk
|
| 160 |
save_path = self.vector_store_path / f"{store_name}.faiss"
|
| 161 |
vector_store.save_local(str(self.vector_store_path), index_name=store_name)
|
| 162 |
-
|
| 163 |
print(f"OK: Vector store created and saved to: {save_path}")
|
| 164 |
-
|
| 165 |
return vector_store
|
| 166 |
-
|
| 167 |
def load_vector_store(
|
| 168 |
self,
|
| 169 |
embedding_model,
|
| 170 |
store_name: str = "medical_knowledge"
|
| 171 |
-
) ->
|
| 172 |
"""
|
| 173 |
Load existing vector store from disk.
|
| 174 |
|
|
@@ -180,11 +179,11 @@ class PDFProcessor:
|
|
| 180 |
FAISS vector store or None if not found
|
| 181 |
"""
|
| 182 |
store_path = self.vector_store_path / f"{store_name}.faiss"
|
| 183 |
-
|
| 184 |
if not store_path.exists():
|
| 185 |
print(f"WARN: Vector store not found: {store_path}")
|
| 186 |
return None
|
| 187 |
-
|
| 188 |
try:
|
| 189 |
# SECURITY NOTE: allow_dangerous_deserialization=True uses pickle.
|
| 190 |
# Only load vector stores from trusted, locally-built sources.
|
|
@@ -197,11 +196,11 @@ class PDFProcessor:
|
|
| 197 |
)
|
| 198 |
print(f"OK: Loaded vector store from: {store_path}")
|
| 199 |
return vector_store
|
| 200 |
-
|
| 201 |
except Exception as e:
|
| 202 |
print(f"ERROR: Error loading vector store: {e}")
|
| 203 |
return None
|
| 204 |
-
|
| 205 |
def create_retrievers(
|
| 206 |
self,
|
| 207 |
embedding_model,
|
|
@@ -224,19 +223,19 @@ class PDFProcessor:
|
|
| 224 |
vector_store = self.load_vector_store(embedding_model, store_name)
|
| 225 |
else:
|
| 226 |
vector_store = None
|
| 227 |
-
|
| 228 |
# If not found, create new one
|
| 229 |
if vector_store is None:
|
| 230 |
print("\nBuilding new vector store from PDFs...")
|
| 231 |
documents = self.load_pdfs()
|
| 232 |
-
|
| 233 |
if not documents:
|
| 234 |
print("WARN: No documents to process. Please add PDF files.")
|
| 235 |
return {}
|
| 236 |
-
|
| 237 |
chunks = self.chunk_documents(documents)
|
| 238 |
vector_store = self.create_vector_store(chunks, embedding_model, store_name)
|
| 239 |
-
|
| 240 |
# Create specialized retrievers
|
| 241 |
retrievers = {
|
| 242 |
"disease_explainer": vector_store.as_retriever(
|
|
@@ -252,7 +251,7 @@ class PDFProcessor:
|
|
| 252 |
search_kwargs={"k": 5}
|
| 253 |
)
|
| 254 |
}
|
| 255 |
-
|
| 256 |
print(f"\nOK: Created {len(retrievers)} specialized retrievers")
|
| 257 |
return retrievers
|
| 258 |
|
|
@@ -272,28 +271,28 @@ def setup_knowledge_base(embedding_model=None, force_rebuild: bool = False, use_
|
|
| 272 |
print("=" * 60)
|
| 273 |
print("Setting up Medical Knowledge Base")
|
| 274 |
print("=" * 60)
|
| 275 |
-
|
| 276 |
# Use configured embedding provider from environment
|
| 277 |
if use_configured_embeddings and embedding_model is None:
|
| 278 |
embedding_model = get_embedding_model()
|
| 279 |
print(" > Embeddings model loaded")
|
| 280 |
elif embedding_model is None:
|
| 281 |
raise ValueError("Must provide embedding_model or set use_configured_embeddings=True")
|
| 282 |
-
|
| 283 |
processor = PDFProcessor()
|
| 284 |
retrievers = processor.create_retrievers(
|
| 285 |
embedding_model,
|
| 286 |
store_name="medical_knowledge",
|
| 287 |
force_rebuild=force_rebuild
|
| 288 |
)
|
| 289 |
-
|
| 290 |
if retrievers:
|
| 291 |
print("\nOK: Knowledge base setup complete!")
|
| 292 |
else:
|
| 293 |
print("\nWARN: Knowledge base setup incomplete. Add PDFs and try again.")
|
| 294 |
-
|
| 295 |
print("=" * 60)
|
| 296 |
-
|
| 297 |
return retrievers
|
| 298 |
|
| 299 |
|
|
@@ -320,22 +319,22 @@ if __name__ == "__main__":
|
|
| 320 |
# Test PDF processing
|
| 321 |
import sys
|
| 322 |
from pathlib import Path
|
| 323 |
-
|
| 324 |
# Add parent directory to path for imports
|
| 325 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 326 |
-
|
| 327 |
print("\n" + "="*70)
|
| 328 |
print("MediGuard AI - PDF Knowledge Base Builder")
|
| 329 |
print("="*70)
|
| 330 |
print("\nUsing configured embedding provider from .env")
|
| 331 |
print(" EMBEDDING_PROVIDER options: google (default), huggingface, ollama")
|
| 332 |
print("="*70)
|
| 333 |
-
|
| 334 |
retrievers = setup_knowledge_base(
|
| 335 |
use_configured_embeddings=True, # Use configured provider
|
| 336 |
force_rebuild=False
|
| 337 |
)
|
| 338 |
-
|
| 339 |
if retrievers:
|
| 340 |
print("\nOK: PDF processing test successful!")
|
| 341 |
print(f"Available retrievers: {list(retrievers.keys())}")
|
|
|
|
| 6 |
import os
|
| 7 |
import warnings
|
| 8 |
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 12 |
from langchain_community.vectorstores import FAISS
|
| 13 |
from langchain_core.documents import Document
|
| 14 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
|
|
| 15 |
|
| 16 |
# Suppress noisy warnings
|
| 17 |
warnings.filterwarnings("ignore", message=".*class.*HuggingFaceEmbeddings.*was deprecated.*")
|
|
|
|
| 21 |
load_dotenv()
|
| 22 |
|
| 23 |
# Re-export for backward compatibility
|
| 24 |
+
from src.llm_config import get_embedding_model
|
| 25 |
|
| 26 |
|
| 27 |
class PDFProcessor:
|
| 28 |
"""Handles medical PDF ingestion and vector store creation"""
|
| 29 |
+
|
| 30 |
def __init__(
|
| 31 |
self,
|
| 32 |
pdf_directory: str = "data/medical_pdfs",
|
|
|
|
| 47 |
self.vector_store_path = Path(vector_store_path)
|
| 48 |
self.chunk_size = chunk_size
|
| 49 |
self.chunk_overlap = chunk_overlap
|
| 50 |
+
|
| 51 |
# Create directories if they don't exist
|
| 52 |
self.pdf_directory.mkdir(parents=True, exist_ok=True)
|
| 53 |
self.vector_store_path.mkdir(parents=True, exist_ok=True)
|
| 54 |
+
|
| 55 |
# Text splitter with medical context awareness
|
| 56 |
self.text_splitter = RecursiveCharacterTextSplitter(
|
| 57 |
chunk_size=chunk_size,
|
|
|
|
| 59 |
separators=["\n\n", "\n", ". ", " ", ""],
|
| 60 |
length_function=len
|
| 61 |
)
|
| 62 |
+
|
| 63 |
+
def load_pdfs(self) -> list[Document]:
|
| 64 |
"""
|
| 65 |
Load all PDF documents from the configured directory.
|
| 66 |
|
|
|
|
| 68 |
List of Document objects with content and metadata
|
| 69 |
"""
|
| 70 |
print(f"Loading PDFs from: {self.pdf_directory}")
|
| 71 |
+
|
| 72 |
pdf_files = list(self.pdf_directory.glob("*.pdf"))
|
| 73 |
+
|
| 74 |
if not pdf_files:
|
| 75 |
print(f"WARN: No PDF files found in {self.pdf_directory}")
|
| 76 |
print("INFO: Please place medical PDFs in this directory")
|
| 77 |
return []
|
| 78 |
+
|
| 79 |
print(f"Found {len(pdf_files)} PDF file(s):")
|
| 80 |
for pdf in pdf_files:
|
| 81 |
print(f" - {pdf.name}")
|
| 82 |
+
|
| 83 |
documents = []
|
| 84 |
+
|
| 85 |
for pdf_path in pdf_files:
|
| 86 |
try:
|
| 87 |
loader = PyPDFLoader(str(pdf_path))
|
| 88 |
docs = loader.load()
|
| 89 |
+
|
| 90 |
# Add source filename to metadata
|
| 91 |
for doc in docs:
|
| 92 |
doc.metadata['source_file'] = pdf_path.name
|
| 93 |
doc.metadata['source_path'] = str(pdf_path)
|
| 94 |
+
|
| 95 |
documents.extend(docs)
|
| 96 |
print(f" OK: Loaded {len(docs)} pages from {pdf_path.name}")
|
| 97 |
+
|
| 98 |
except Exception as e:
|
| 99 |
print(f" ERROR: Error loading {pdf_path.name}: {e}")
|
| 100 |
+
|
| 101 |
print(f"\nTotal: {len(documents)} pages loaded from {len(pdf_files)} PDF(s)")
|
| 102 |
return documents
|
| 103 |
+
|
| 104 |
+
def chunk_documents(self, documents: list[Document]) -> list[Document]:
|
| 105 |
"""
|
| 106 |
Split documents into chunks for RAG retrieval.
|
| 107 |
|
|
|
|
| 112 |
List of chunked documents with preserved metadata
|
| 113 |
"""
|
| 114 |
print(f"\nChunking documents (size={self.chunk_size}, overlap={self.chunk_overlap})...")
|
| 115 |
+
|
| 116 |
chunks = self.text_splitter.split_documents(documents)
|
| 117 |
+
|
| 118 |
if not chunks:
|
| 119 |
print("WARN: No chunks generated from documents")
|
| 120 |
return chunks
|
| 121 |
+
|
| 122 |
# Add chunk index to metadata
|
| 123 |
for i, chunk in enumerate(chunks):
|
| 124 |
chunk.metadata['chunk_id'] = i
|
| 125 |
+
|
| 126 |
print(f"OK: Created {len(chunks)} chunks from {len(documents)} pages")
|
| 127 |
print(f" Average chunk size: {sum(len(c.page_content) for c in chunks) // len(chunks)} characters")
|
| 128 |
+
|
| 129 |
return chunks
|
| 130 |
+
|
| 131 |
def create_vector_store(
|
| 132 |
self,
|
| 133 |
+
chunks: list[Document],
|
| 134 |
embedding_model,
|
| 135 |
store_name: str = "medical_knowledge"
|
| 136 |
) -> FAISS:
|
|
|
|
| 148 |
print(f"\nCreating vector store: {store_name}")
|
| 149 |
print(f"Generating embeddings for {len(chunks)} chunks...")
|
| 150 |
print("(This may take a few minutes...)")
|
| 151 |
+
|
| 152 |
# Create FAISS vector store
|
| 153 |
vector_store = FAISS.from_documents(
|
| 154 |
documents=chunks,
|
| 155 |
embedding=embedding_model
|
| 156 |
)
|
| 157 |
+
|
| 158 |
# Save to disk
|
| 159 |
save_path = self.vector_store_path / f"{store_name}.faiss"
|
| 160 |
vector_store.save_local(str(self.vector_store_path), index_name=store_name)
|
| 161 |
+
|
| 162 |
print(f"OK: Vector store created and saved to: {save_path}")
|
| 163 |
+
|
| 164 |
return vector_store
|
| 165 |
+
|
| 166 |
def load_vector_store(
|
| 167 |
self,
|
| 168 |
embedding_model,
|
| 169 |
store_name: str = "medical_knowledge"
|
| 170 |
+
) -> FAISS | None:
|
| 171 |
"""
|
| 172 |
Load existing vector store from disk.
|
| 173 |
|
|
|
|
| 179 |
FAISS vector store or None if not found
|
| 180 |
"""
|
| 181 |
store_path = self.vector_store_path / f"{store_name}.faiss"
|
| 182 |
+
|
| 183 |
if not store_path.exists():
|
| 184 |
print(f"WARN: Vector store not found: {store_path}")
|
| 185 |
return None
|
| 186 |
+
|
| 187 |
try:
|
| 188 |
# SECURITY NOTE: allow_dangerous_deserialization=True uses pickle.
|
| 189 |
# Only load vector stores from trusted, locally-built sources.
|
|
|
|
| 196 |
)
|
| 197 |
print(f"OK: Loaded vector store from: {store_path}")
|
| 198 |
return vector_store
|
| 199 |
+
|
| 200 |
except Exception as e:
|
| 201 |
print(f"ERROR: Error loading vector store: {e}")
|
| 202 |
return None
|
| 203 |
+
|
| 204 |
def create_retrievers(
|
| 205 |
self,
|
| 206 |
embedding_model,
|
|
|
|
| 223 |
vector_store = self.load_vector_store(embedding_model, store_name)
|
| 224 |
else:
|
| 225 |
vector_store = None
|
| 226 |
+
|
| 227 |
# If not found, create new one
|
| 228 |
if vector_store is None:
|
| 229 |
print("\nBuilding new vector store from PDFs...")
|
| 230 |
documents = self.load_pdfs()
|
| 231 |
+
|
| 232 |
if not documents:
|
| 233 |
print("WARN: No documents to process. Please add PDF files.")
|
| 234 |
return {}
|
| 235 |
+
|
| 236 |
chunks = self.chunk_documents(documents)
|
| 237 |
vector_store = self.create_vector_store(chunks, embedding_model, store_name)
|
| 238 |
+
|
| 239 |
# Create specialized retrievers
|
| 240 |
retrievers = {
|
| 241 |
"disease_explainer": vector_store.as_retriever(
|
|
|
|
| 251 |
search_kwargs={"k": 5}
|
| 252 |
)
|
| 253 |
}
|
| 254 |
+
|
| 255 |
print(f"\nOK: Created {len(retrievers)} specialized retrievers")
|
| 256 |
return retrievers
|
| 257 |
|
|
|
|
| 271 |
print("=" * 60)
|
| 272 |
print("Setting up Medical Knowledge Base")
|
| 273 |
print("=" * 60)
|
| 274 |
+
|
| 275 |
# Use configured embedding provider from environment
|
| 276 |
if use_configured_embeddings and embedding_model is None:
|
| 277 |
embedding_model = get_embedding_model()
|
| 278 |
print(" > Embeddings model loaded")
|
| 279 |
elif embedding_model is None:
|
| 280 |
raise ValueError("Must provide embedding_model or set use_configured_embeddings=True")
|
| 281 |
+
|
| 282 |
processor = PDFProcessor()
|
| 283 |
retrievers = processor.create_retrievers(
|
| 284 |
embedding_model,
|
| 285 |
store_name="medical_knowledge",
|
| 286 |
force_rebuild=force_rebuild
|
| 287 |
)
|
| 288 |
+
|
| 289 |
if retrievers:
|
| 290 |
print("\nOK: Knowledge base setup complete!")
|
| 291 |
else:
|
| 292 |
print("\nWARN: Knowledge base setup incomplete. Add PDFs and try again.")
|
| 293 |
+
|
| 294 |
print("=" * 60)
|
| 295 |
+
|
| 296 |
return retrievers
|
| 297 |
|
| 298 |
|
|
|
|
| 319 |
# Test PDF processing
|
| 320 |
import sys
|
| 321 |
from pathlib import Path
|
| 322 |
+
|
| 323 |
# Add parent directory to path for imports
|
| 324 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 325 |
+
|
| 326 |
print("\n" + "="*70)
|
| 327 |
print("MediGuard AI - PDF Knowledge Base Builder")
|
| 328 |
print("="*70)
|
| 329 |
print("\nUsing configured embedding provider from .env")
|
| 330 |
print(" EMBEDDING_PROVIDER options: google (default), huggingface, ollama")
|
| 331 |
print("="*70)
|
| 332 |
+
|
| 333 |
retrievers = setup_knowledge_base(
|
| 334 |
use_configured_embeddings=True, # Use configured provider
|
| 335 |
force_rebuild=False
|
| 336 |
)
|
| 337 |
+
|
| 338 |
if retrievers:
|
| 339 |
print("\nOK: PDF processing test successful!")
|
| 340 |
print(f"Available retrievers: {list(retrievers.keys())}")
|
src/repositories/analysis.py
CHANGED
|
@@ -4,8 +4,6 @@ MediGuard AI — Analysis repository (data-access layer).
|
|
| 4 |
|
| 5 |
from __future__ import annotations
|
| 6 |
|
| 7 |
-
from typing import List, Optional
|
| 8 |
-
|
| 9 |
from sqlalchemy.orm import Session
|
| 10 |
|
| 11 |
from src.models.analysis import PatientAnalysis
|
|
@@ -22,14 +20,14 @@ class AnalysisRepository:
|
|
| 22 |
self.db.flush()
|
| 23 |
return analysis
|
| 24 |
|
| 25 |
-
def get_by_request_id(self, request_id: str) ->
|
| 26 |
return (
|
| 27 |
self.db.query(PatientAnalysis)
|
| 28 |
.filter(PatientAnalysis.request_id == request_id)
|
| 29 |
.first()
|
| 30 |
)
|
| 31 |
|
| 32 |
-
def list_recent(self, limit: int = 20) ->
|
| 33 |
return (
|
| 34 |
self.db.query(PatientAnalysis)
|
| 35 |
.order_by(PatientAnalysis.created_at.desc())
|
|
|
|
| 4 |
|
| 5 |
from __future__ import annotations
|
| 6 |
|
|
|
|
|
|
|
| 7 |
from sqlalchemy.orm import Session
|
| 8 |
|
| 9 |
from src.models.analysis import PatientAnalysis
|
|
|
|
| 20 |
self.db.flush()
|
| 21 |
return analysis
|
| 22 |
|
| 23 |
+
def get_by_request_id(self, request_id: str) -> PatientAnalysis | None:
|
| 24 |
return (
|
| 25 |
self.db.query(PatientAnalysis)
|
| 26 |
.filter(PatientAnalysis.request_id == request_id)
|
| 27 |
.first()
|
| 28 |
)
|
| 29 |
|
| 30 |
+
def list_recent(self, limit: int = 20) -> list[PatientAnalysis]:
|
| 31 |
return (
|
| 32 |
self.db.query(PatientAnalysis)
|
| 33 |
.order_by(PatientAnalysis.created_at.desc())
|
src/repositories/document.py
CHANGED
|
@@ -4,8 +4,6 @@ MediGuard AI — Document repository.
|
|
| 4 |
|
| 5 |
from __future__ import annotations
|
| 6 |
|
| 7 |
-
from typing import List, Optional
|
| 8 |
-
|
| 9 |
from sqlalchemy.orm import Session
|
| 10 |
|
| 11 |
from src.models.analysis import MedicalDocument
|
|
@@ -33,10 +31,10 @@ class DocumentRepository:
|
|
| 33 |
self.db.flush()
|
| 34 |
return doc
|
| 35 |
|
| 36 |
-
def get_by_id(self, doc_id: str) ->
|
| 37 |
return self.db.query(MedicalDocument).filter(MedicalDocument.id == doc_id).first()
|
| 38 |
|
| 39 |
-
def list_all(self, limit: int = 100) ->
|
| 40 |
return (
|
| 41 |
self.db.query(MedicalDocument)
|
| 42 |
.order_by(MedicalDocument.created_at.desc())
|
|
|
|
| 4 |
|
| 5 |
from __future__ import annotations
|
| 6 |
|
|
|
|
|
|
|
| 7 |
from sqlalchemy.orm import Session
|
| 8 |
|
| 9 |
from src.models.analysis import MedicalDocument
|
|
|
|
| 31 |
self.db.flush()
|
| 32 |
return doc
|
| 33 |
|
| 34 |
+
def get_by_id(self, doc_id: str) -> MedicalDocument | None:
|
| 35 |
return self.db.query(MedicalDocument).filter(MedicalDocument.id == doc_id).first()
|
| 36 |
|
| 37 |
+
def list_all(self, limit: int = 100) -> list[MedicalDocument]:
|
| 38 |
return (
|
| 39 |
self.db.query(MedicalDocument)
|
| 40 |
.order_by(MedicalDocument.created_at.desc())
|
src/routers/analyze.py
CHANGED
|
@@ -12,8 +12,8 @@ import logging
|
|
| 12 |
import time
|
| 13 |
import uuid
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
-
from datetime import
|
| 16 |
-
from typing import Any
|
| 17 |
|
| 18 |
from fastapi import APIRouter, HTTPException, Request
|
| 19 |
|
|
@@ -30,7 +30,7 @@ router = APIRouter(prefix="/analyze", tags=["analysis"])
|
|
| 30 |
_executor = ThreadPoolExecutor(max_workers=4)
|
| 31 |
|
| 32 |
|
| 33 |
-
def _score_disease_heuristic(biomarkers:
|
| 34 |
"""Rule-based disease scoring (NOT ML prediction)."""
|
| 35 |
scores = {
|
| 36 |
"Diabetes": 0.0,
|
|
@@ -39,7 +39,7 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 39 |
"Thrombocytopenia": 0.0,
|
| 40 |
"Thalassemia": 0.0
|
| 41 |
}
|
| 42 |
-
|
| 43 |
# Diabetes indicators
|
| 44 |
glucose = biomarkers.get("Glucose")
|
| 45 |
hba1c = biomarkers.get("HbA1c")
|
|
@@ -49,7 +49,7 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 49 |
scores["Diabetes"] += 0.2
|
| 50 |
if hba1c is not None and hba1c >= 6.5:
|
| 51 |
scores["Diabetes"] += 0.5
|
| 52 |
-
|
| 53 |
# Anemia indicators
|
| 54 |
hemoglobin = biomarkers.get("Hemoglobin")
|
| 55 |
mcv = biomarkers.get("Mean Corpuscular Volume", biomarkers.get("MCV"))
|
|
@@ -59,7 +59,7 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 59 |
scores["Anemia"] += 0.2
|
| 60 |
if mcv is not None and mcv < 80:
|
| 61 |
scores["Anemia"] += 0.2
|
| 62 |
-
|
| 63 |
# Heart disease indicators
|
| 64 |
cholesterol = biomarkers.get("Cholesterol")
|
| 65 |
troponin = biomarkers.get("Troponin")
|
|
@@ -70,32 +70,32 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 70 |
scores["Heart Disease"] += 0.6
|
| 71 |
if ldl is not None and ldl > 190:
|
| 72 |
scores["Heart Disease"] += 0.2
|
| 73 |
-
|
| 74 |
# Thrombocytopenia indicators
|
| 75 |
platelets = biomarkers.get("Platelets")
|
| 76 |
if platelets is not None and platelets < 150000:
|
| 77 |
scores["Thrombocytopenia"] += 0.6
|
| 78 |
if platelets is not None and platelets < 50000:
|
| 79 |
scores["Thrombocytopenia"] += 0.3
|
| 80 |
-
|
| 81 |
# Thalassemia indicators
|
| 82 |
if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
|
| 83 |
scores["Thalassemia"] += 0.4
|
| 84 |
-
|
| 85 |
# Find top prediction
|
| 86 |
top_disease = max(scores, key=scores.get)
|
| 87 |
confidence = min(scores[top_disease], 1.0)
|
| 88 |
-
|
| 89 |
if confidence == 0.0:
|
| 90 |
top_disease = "Undetermined"
|
| 91 |
-
|
| 92 |
# Normalize probabilities
|
| 93 |
total = sum(scores.values())
|
| 94 |
if total > 0:
|
| 95 |
probabilities = {k: v / total for k, v in scores.items()}
|
| 96 |
else:
|
| 97 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 98 |
-
|
| 99 |
return {
|
| 100 |
"disease": top_disease,
|
| 101 |
"confidence": confidence,
|
|
@@ -105,16 +105,16 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
|
|
| 105 |
|
| 106 |
async def _run_guild_analysis(
|
| 107 |
request: Request,
|
| 108 |
-
biomarkers:
|
| 109 |
-
patient_ctx:
|
| 110 |
-
extracted_biomarkers:
|
| 111 |
) -> AnalysisResponse:
|
| 112 |
"""Execute the ClinicalInsightGuild and build the response envelope."""
|
| 113 |
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 114 |
t0 = time.time()
|
| 115 |
|
| 116 |
ragbot = getattr(request.app.state, "ragbot_service", None)
|
| 117 |
-
if ragbot is None
|
| 118 |
raise HTTPException(status_code=503, detail="Analysis service unavailable. Please wait for initialization.")
|
| 119 |
|
| 120 |
# Generate disease prediction
|
|
@@ -122,15 +122,16 @@ async def _run_guild_analysis(
|
|
| 122 |
|
| 123 |
try:
|
| 124 |
# Run sync function in thread pool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
loop = asyncio.get_running_loop()
|
| 126 |
result = await loop.run_in_executor(
|
| 127 |
_executor,
|
| 128 |
-
lambda: ragbot.
|
| 129 |
-
biomarkers=biomarkers,
|
| 130 |
-
patient_context=patient_ctx,
|
| 131 |
-
model_prediction=model_prediction,
|
| 132 |
-
extracted_biomarkers=extracted_biomarkers
|
| 133 |
-
)
|
| 134 |
)
|
| 135 |
except Exception as exc:
|
| 136 |
logger.exception("Guild analysis failed: %s", exc)
|
|
@@ -142,20 +143,15 @@ async def _run_guild_analysis(
|
|
| 142 |
elapsed = (time.time() - t0) * 1000
|
| 143 |
|
| 144 |
# Build response from result
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
conversational_summary = result.get('conversational_summary')
|
| 150 |
-
else:
|
| 151 |
-
prediction = getattr(result, 'prediction', None)
|
| 152 |
-
analysis = getattr(result, 'analysis', None)
|
| 153 |
-
conversational_summary = getattr(result, 'conversational_summary', None)
|
| 154 |
|
| 155 |
return AnalysisResponse(
|
| 156 |
status="success",
|
| 157 |
request_id=request_id,
|
| 158 |
-
timestamp=datetime.now(
|
| 159 |
extracted_biomarkers=extracted_biomarkers,
|
| 160 |
input_biomarkers=biomarkers,
|
| 161 |
patient_context=patient_ctx,
|
|
|
|
| 12 |
import time
|
| 13 |
import uuid
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
+
from datetime import UTC, datetime
|
| 16 |
+
from typing import Any
|
| 17 |
|
| 18 |
from fastapi import APIRouter, HTTPException, Request
|
| 19 |
|
|
|
|
| 30 |
_executor = ThreadPoolExecutor(max_workers=4)
|
| 31 |
|
| 32 |
|
| 33 |
+
def _score_disease_heuristic(biomarkers: dict[str, float]) -> dict[str, Any]:
|
| 34 |
"""Rule-based disease scoring (NOT ML prediction)."""
|
| 35 |
scores = {
|
| 36 |
"Diabetes": 0.0,
|
|
|
|
| 39 |
"Thrombocytopenia": 0.0,
|
| 40 |
"Thalassemia": 0.0
|
| 41 |
}
|
| 42 |
+
|
| 43 |
# Diabetes indicators
|
| 44 |
glucose = biomarkers.get("Glucose")
|
| 45 |
hba1c = biomarkers.get("HbA1c")
|
|
|
|
| 49 |
scores["Diabetes"] += 0.2
|
| 50 |
if hba1c is not None and hba1c >= 6.5:
|
| 51 |
scores["Diabetes"] += 0.5
|
| 52 |
+
|
| 53 |
# Anemia indicators
|
| 54 |
hemoglobin = biomarkers.get("Hemoglobin")
|
| 55 |
mcv = biomarkers.get("Mean Corpuscular Volume", biomarkers.get("MCV"))
|
|
|
|
| 59 |
scores["Anemia"] += 0.2
|
| 60 |
if mcv is not None and mcv < 80:
|
| 61 |
scores["Anemia"] += 0.2
|
| 62 |
+
|
| 63 |
# Heart disease indicators
|
| 64 |
cholesterol = biomarkers.get("Cholesterol")
|
| 65 |
troponin = biomarkers.get("Troponin")
|
|
|
|
| 70 |
scores["Heart Disease"] += 0.6
|
| 71 |
if ldl is not None and ldl > 190:
|
| 72 |
scores["Heart Disease"] += 0.2
|
| 73 |
+
|
| 74 |
# Thrombocytopenia indicators
|
| 75 |
platelets = biomarkers.get("Platelets")
|
| 76 |
if platelets is not None and platelets < 150000:
|
| 77 |
scores["Thrombocytopenia"] += 0.6
|
| 78 |
if platelets is not None and platelets < 50000:
|
| 79 |
scores["Thrombocytopenia"] += 0.3
|
| 80 |
+
|
| 81 |
# Thalassemia indicators
|
| 82 |
if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
|
| 83 |
scores["Thalassemia"] += 0.4
|
| 84 |
+
|
| 85 |
# Find top prediction
|
| 86 |
top_disease = max(scores, key=scores.get)
|
| 87 |
confidence = min(scores[top_disease], 1.0)
|
| 88 |
+
|
| 89 |
if confidence == 0.0:
|
| 90 |
top_disease = "Undetermined"
|
| 91 |
+
|
| 92 |
# Normalize probabilities
|
| 93 |
total = sum(scores.values())
|
| 94 |
if total > 0:
|
| 95 |
probabilities = {k: v / total for k, v in scores.items()}
|
| 96 |
else:
|
| 97 |
probabilities = {k: 1.0 / len(scores) for k in scores}
|
| 98 |
+
|
| 99 |
return {
|
| 100 |
"disease": top_disease,
|
| 101 |
"confidence": confidence,
|
|
|
|
| 105 |
|
| 106 |
async def _run_guild_analysis(
|
| 107 |
request: Request,
|
| 108 |
+
biomarkers: dict[str, float],
|
| 109 |
+
patient_ctx: dict[str, Any],
|
| 110 |
+
extracted_biomarkers: dict[str, float] | None = None,
|
| 111 |
) -> AnalysisResponse:
|
| 112 |
"""Execute the ClinicalInsightGuild and build the response envelope."""
|
| 113 |
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 114 |
t0 = time.time()
|
| 115 |
|
| 116 |
ragbot = getattr(request.app.state, "ragbot_service", None)
|
| 117 |
+
if ragbot is None:
|
| 118 |
raise HTTPException(status_code=503, detail="Analysis service unavailable. Please wait for initialization.")
|
| 119 |
|
| 120 |
# Generate disease prediction
|
|
|
|
| 122 |
|
| 123 |
try:
|
| 124 |
# Run sync function in thread pool
|
| 125 |
+
from src.state import PatientInput
|
| 126 |
+
patient_input = PatientInput(
|
| 127 |
+
biomarkers=biomarkers,
|
| 128 |
+
patient_context=patient_ctx,
|
| 129 |
+
model_prediction=model_prediction
|
| 130 |
+
)
|
| 131 |
loop = asyncio.get_running_loop()
|
| 132 |
result = await loop.run_in_executor(
|
| 133 |
_executor,
|
| 134 |
+
lambda: ragbot.run(patient_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
)
|
| 136 |
except Exception as exc:
|
| 137 |
logger.exception("Guild analysis failed: %s", exc)
|
|
|
|
| 143 |
elapsed = (time.time() - t0) * 1000
|
| 144 |
|
| 145 |
# Build response from result
|
| 146 |
+
prediction = result.get('model_prediction')
|
| 147 |
+
analysis = result.get('final_response', {})
|
| 148 |
+
# Try to extract the conversational_summary if it's there
|
| 149 |
+
conversational_summary = analysis.get('conversational_summary') if isinstance(analysis, dict) else str(analysis)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
return AnalysisResponse(
|
| 152 |
status="success",
|
| 153 |
request_id=request_id,
|
| 154 |
+
timestamp=datetime.now(UTC).isoformat(),
|
| 155 |
extracted_biomarkers=extracted_biomarkers,
|
| 156 |
input_biomarkers=biomarkers,
|
| 157 |
patient_context=patient_ctx,
|
src/routers/ask.py
CHANGED
|
@@ -12,13 +12,12 @@ import json
|
|
| 12 |
import logging
|
| 13 |
import time
|
| 14 |
import uuid
|
| 15 |
-
from
|
| 16 |
-
from typing import AsyncGenerator
|
| 17 |
|
| 18 |
from fastapi import APIRouter, HTTPException, Request
|
| 19 |
from fastapi.responses import StreamingResponse
|
| 20 |
|
| 21 |
-
from src.schemas.schemas import AskRequest, AskResponse
|
| 22 |
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
router = APIRouter(tags=["ask"])
|
|
@@ -81,12 +80,12 @@ async def _stream_rag_response(
|
|
| 81 |
- error: Error information
|
| 82 |
"""
|
| 83 |
t0 = time.time()
|
| 84 |
-
|
| 85 |
try:
|
| 86 |
# Send initial status
|
| 87 |
yield f"event: status\ndata: {json.dumps({'stage': 'guardrail', 'message': 'Validating query...'})}\n\n"
|
| 88 |
await asyncio.sleep(0) # Allow event loop to flush
|
| 89 |
-
|
| 90 |
# Run the RAG pipeline (synchronous, but we yield progress)
|
| 91 |
loop = asyncio.get_running_loop()
|
| 92 |
result = await loop.run_in_executor(
|
|
@@ -97,16 +96,16 @@ async def _stream_rag_response(
|
|
| 97 |
patient_context=patient_context,
|
| 98 |
)
|
| 99 |
)
|
| 100 |
-
|
| 101 |
# Send retrieval metadata
|
| 102 |
yield f"event: metadata\ndata: {json.dumps({'documents_retrieved': len(result.get('retrieved_documents', [])), 'documents_relevant': len(result.get('relevant_documents', [])), 'guardrail_score': result.get('guardrail_score')})}\n\n"
|
| 103 |
await asyncio.sleep(0)
|
| 104 |
-
|
| 105 |
# Stream the answer token by token for smooth UI
|
| 106 |
answer = result.get("final_answer", "")
|
| 107 |
if answer:
|
| 108 |
yield f"event: status\ndata: {json.dumps({'stage': 'generating', 'message': 'Generating response...'})}\n\n"
|
| 109 |
-
|
| 110 |
# Simulate streaming by chunking the response
|
| 111 |
words = answer.split()
|
| 112 |
chunk_size = 3 # Send 3 words at a time
|
|
@@ -116,11 +115,11 @@ async def _stream_rag_response(
|
|
| 116 |
chunk += " "
|
| 117 |
yield f"event: token\ndata: {json.dumps({'text': chunk})}\n\n"
|
| 118 |
await asyncio.sleep(0.02) # Small delay for visual streaming effect
|
| 119 |
-
|
| 120 |
# Send completion
|
| 121 |
elapsed = (time.time() - t0) * 1000
|
| 122 |
yield f"event: done\ndata: {json.dumps({'request_id': request_id, 'processing_time_ms': round(elapsed, 1), 'status': 'success'})}\n\n"
|
| 123 |
-
|
| 124 |
except Exception as exc:
|
| 125 |
logger.exception("Streaming RAG failed: %s", exc)
|
| 126 |
yield f"event: error\ndata: {json.dumps({'error': str(exc), 'request_id': request_id})}\n\n"
|
|
@@ -154,9 +153,9 @@ async def ask_medical_question_stream(body: AskRequest, request: Request):
|
|
| 154 |
rag_service = getattr(request.app.state, "rag_service", None)
|
| 155 |
if rag_service is None:
|
| 156 |
raise HTTPException(status_code=503, detail="RAG service unavailable")
|
| 157 |
-
|
| 158 |
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 159 |
-
|
| 160 |
return StreamingResponse(
|
| 161 |
_stream_rag_response(
|
| 162 |
rag_service,
|
|
@@ -172,3 +171,17 @@ async def ask_medical_question_stream(body: AskRequest, request: Request):
|
|
| 172 |
"X-Request-ID": request_id,
|
| 173 |
},
|
| 174 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
import logging
|
| 13 |
import time
|
| 14 |
import uuid
|
| 15 |
+
from collections.abc import AsyncGenerator
|
|
|
|
| 16 |
|
| 17 |
from fastapi import APIRouter, HTTPException, Request
|
| 18 |
from fastapi.responses import StreamingResponse
|
| 19 |
|
| 20 |
+
from src.schemas.schemas import AskRequest, AskResponse, FeedbackRequest, FeedbackResponse
|
| 21 |
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
router = APIRouter(tags=["ask"])
|
|
|
|
| 80 |
- error: Error information
|
| 81 |
"""
|
| 82 |
t0 = time.time()
|
| 83 |
+
|
| 84 |
try:
|
| 85 |
# Send initial status
|
| 86 |
yield f"event: status\ndata: {json.dumps({'stage': 'guardrail', 'message': 'Validating query...'})}\n\n"
|
| 87 |
await asyncio.sleep(0) # Allow event loop to flush
|
| 88 |
+
|
| 89 |
# Run the RAG pipeline (synchronous, but we yield progress)
|
| 90 |
loop = asyncio.get_running_loop()
|
| 91 |
result = await loop.run_in_executor(
|
|
|
|
| 96 |
patient_context=patient_context,
|
| 97 |
)
|
| 98 |
)
|
| 99 |
+
|
| 100 |
# Send retrieval metadata
|
| 101 |
yield f"event: metadata\ndata: {json.dumps({'documents_retrieved': len(result.get('retrieved_documents', [])), 'documents_relevant': len(result.get('relevant_documents', [])), 'guardrail_score': result.get('guardrail_score')})}\n\n"
|
| 102 |
await asyncio.sleep(0)
|
| 103 |
+
|
| 104 |
# Stream the answer token by token for smooth UI
|
| 105 |
answer = result.get("final_answer", "")
|
| 106 |
if answer:
|
| 107 |
yield f"event: status\ndata: {json.dumps({'stage': 'generating', 'message': 'Generating response...'})}\n\n"
|
| 108 |
+
|
| 109 |
# Simulate streaming by chunking the response
|
| 110 |
words = answer.split()
|
| 111 |
chunk_size = 3 # Send 3 words at a time
|
|
|
|
| 115 |
chunk += " "
|
| 116 |
yield f"event: token\ndata: {json.dumps({'text': chunk})}\n\n"
|
| 117 |
await asyncio.sleep(0.02) # Small delay for visual streaming effect
|
| 118 |
+
|
| 119 |
# Send completion
|
| 120 |
elapsed = (time.time() - t0) * 1000
|
| 121 |
yield f"event: done\ndata: {json.dumps({'request_id': request_id, 'processing_time_ms': round(elapsed, 1), 'status': 'success'})}\n\n"
|
| 122 |
+
|
| 123 |
except Exception as exc:
|
| 124 |
logger.exception("Streaming RAG failed: %s", exc)
|
| 125 |
yield f"event: error\ndata: {json.dumps({'error': str(exc), 'request_id': request_id})}\n\n"
|
|
|
|
| 153 |
rag_service = getattr(request.app.state, "rag_service", None)
|
| 154 |
if rag_service is None:
|
| 155 |
raise HTTPException(status_code=503, detail="RAG service unavailable")
|
| 156 |
+
|
| 157 |
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 158 |
+
|
| 159 |
return StreamingResponse(
|
| 160 |
_stream_rag_response(
|
| 161 |
rag_service,
|
|
|
|
| 171 |
"X-Request-ID": request_id,
|
| 172 |
},
|
| 173 |
)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@router.post("/feedback", response_model=FeedbackResponse)
|
| 177 |
+
async def submit_feedback(body: FeedbackRequest, request: Request):
|
| 178 |
+
"""Submit user feedback for an analysis or RAG response."""
|
| 179 |
+
tracer = getattr(request.app.state, "tracer", None)
|
| 180 |
+
if tracer:
|
| 181 |
+
tracer.score(
|
| 182 |
+
trace_id=body.request_id,
|
| 183 |
+
name="user-feedback",
|
| 184 |
+
value=body.score,
|
| 185 |
+
comment=body.comment
|
| 186 |
+
)
|
| 187 |
+
return FeedbackResponse(request_id=body.request_id)
|
src/routers/health.py
CHANGED
|
@@ -7,7 +7,7 @@ Provides /health and /health/ready with per-service checks.
|
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
import time
|
| 10 |
-
from datetime import
|
| 11 |
|
| 12 |
from fastapi import APIRouter, Request
|
| 13 |
|
|
@@ -23,7 +23,7 @@ async def health_check(request: Request) -> HealthResponse:
|
|
| 23 |
uptime = time.time() - getattr(app_state, "start_time", time.time())
|
| 24 |
return HealthResponse(
|
| 25 |
status="healthy",
|
| 26 |
-
timestamp=datetime.now(
|
| 27 |
version=getattr(app_state, "version", "2.0.0"),
|
| 28 |
uptime_seconds=round(uptime, 2),
|
| 29 |
)
|
|
@@ -39,9 +39,10 @@ async def readiness_check(request: Request) -> HealthResponse:
|
|
| 39 |
|
| 40 |
# --- PostgreSQL ---
|
| 41 |
try:
|
| 42 |
-
from src.database import get_engine
|
| 43 |
from sqlalchemy import text
|
| 44 |
-
|
|
|
|
|
|
|
| 45 |
if engine is not None:
|
| 46 |
t0 = time.time()
|
| 47 |
with engine.connect() as conn:
|
|
@@ -86,9 +87,10 @@ async def readiness_check(request: Request) -> HealthResponse:
|
|
| 86 |
ollama = getattr(app_state, "ollama_client", None)
|
| 87 |
if ollama is not None:
|
| 88 |
t0 = time.time()
|
| 89 |
-
|
| 90 |
latency = (time.time() - t0) * 1000
|
| 91 |
-
|
|
|
|
| 92 |
else:
|
| 93 |
services.append(ServiceHealth(name="ollama", status="unavailable"))
|
| 94 |
except Exception as exc:
|
|
@@ -126,7 +128,7 @@ async def readiness_check(request: Request) -> HealthResponse:
|
|
| 126 |
|
| 127 |
return HealthResponse(
|
| 128 |
status=overall,
|
| 129 |
-
timestamp=datetime.now(
|
| 130 |
version=getattr(app_state, "version", "2.0.0"),
|
| 131 |
uptime_seconds=round(uptime, 2),
|
| 132 |
services=services,
|
|
|
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
import time
|
| 10 |
+
from datetime import UTC, datetime
|
| 11 |
|
| 12 |
from fastapi import APIRouter, Request
|
| 13 |
|
|
|
|
| 23 |
uptime = time.time() - getattr(app_state, "start_time", time.time())
|
| 24 |
return HealthResponse(
|
| 25 |
status="healthy",
|
| 26 |
+
timestamp=datetime.now(UTC).isoformat(),
|
| 27 |
version=getattr(app_state, "version", "2.0.0"),
|
| 28 |
uptime_seconds=round(uptime, 2),
|
| 29 |
)
|
|
|
|
| 39 |
|
| 40 |
# --- PostgreSQL ---
|
| 41 |
try:
|
|
|
|
| 42 |
from sqlalchemy import text
|
| 43 |
+
|
| 44 |
+
from src.database import _engine
|
| 45 |
+
engine = _engine()
|
| 46 |
if engine is not None:
|
| 47 |
t0 = time.time()
|
| 48 |
with engine.connect() as conn:
|
|
|
|
| 87 |
ollama = getattr(app_state, "ollama_client", None)
|
| 88 |
if ollama is not None:
|
| 89 |
t0 = time.time()
|
| 90 |
+
health_info = ollama.health()
|
| 91 |
latency = (time.time() - t0) * 1000
|
| 92 |
+
is_healthy = isinstance(health_info, dict) and health_info.get("status") == "ok"
|
| 93 |
+
services.append(ServiceHealth(name="ollama", status="ok" if is_healthy else "degraded", latency_ms=round(latency, 1)))
|
| 94 |
else:
|
| 95 |
services.append(ServiceHealth(name="ollama", status="unavailable"))
|
| 96 |
except Exception as exc:
|
|
|
|
| 128 |
|
| 129 |
return HealthResponse(
|
| 130 |
status=overall,
|
| 131 |
+
timestamp=datetime.now(UTC).isoformat(),
|
| 132 |
version=getattr(app_state, "version", "2.0.0"),
|
| 133 |
uptime_seconds=round(uptime, 2),
|
| 134 |
services=services,
|