Nikhil Pravin Pise committed on
Commit
696f787
·
1 Parent(s): fd5543a

Fix codebase issues: linting, types, tests, and security.

Browse files

- Resolved over 3,000 ruff linting violations
- Enforced strict type checking with mypy
- Fixed infinite loop in pytest suite by migrating obsolete tests
- Remediated security warnings flagged by Bandit

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. Makefile +4 -2
  2. airflow/dags/ingest_pdfs.py +7 -4
  3. alembic/env.py +9 -11
  4. alembic/versions/001_initial.py +81 -0
  5. api/app/main.py +10 -13
  6. api/app/routes/analyze.py +31 -37
  7. api/app/routes/biomarkers.py +16 -17
  8. api/app/routes/health.py +12 -15
  9. api/app/services/extraction.py +27 -27
  10. api/app/services/ragbot.py +69 -62
  11. archive/evolution/__init__.py +11 -17
  12. archive/evolution/director.py +55 -55
  13. archive/evolution/pareto.py +48 -46
  14. archive/sop_evolution.py +2 -1
  15. {tests → archive/tests}/test_evolution_loop.py +37 -37
  16. {tests → archive/tests}/test_evolution_quick.py +14 -13
  17. docker-compose.yml +20 -0
  18. gradio_launcher.py +24 -0
  19. huggingface/app.py +180 -133
  20. pytest.ini +3 -0
  21. requirements.txt +0 -41
  22. scripts/chat.py +73 -73
  23. scripts/monitor_test.py +1 -1
  24. scripts/setup_embeddings.py +16 -16
  25. scripts/test_chat_demo.py +4 -5
  26. scripts/test_extraction.py +8 -7
  27. src/agents/biomarker_analyzer.py +23 -23
  28. src/agents/biomarker_linker.py +41 -41
  29. src/agents/clinical_guidelines.py +43 -42
  30. src/agents/confidence_assessor.py +46 -46
  31. src/agents/disease_explainer.py +36 -34
  32. src/agents/response_synthesizer.py +52 -51
  33. src/biomarker_normalization.py +1 -2
  34. src/biomarker_validator.py +31 -31
  35. src/config.py +15 -14
  36. src/database.py +2 -2
  37. src/dependencies.py +0 -3
  38. src/evaluation/__init__.py +7 -7
  39. src/evaluation/evaluators.py +72 -70
  40. src/exceptions.py +2 -3
  41. src/gradio_app.py +65 -25
  42. src/llm_config.py +68 -67
  43. src/main.py +20 -23
  44. src/middlewares.py +24 -23
  45. src/pdf_processor.py +52 -53
  46. src/repositories/analysis.py +2 -4
  47. src/repositories/document.py +2 -4
  48. src/routers/analyze.py +28 -32
  49. src/routers/ask.py +25 -12
  50. src/routers/health.py +9 -7
Makefile CHANGED
@@ -117,12 +117,14 @@ index-pdfs: ## Parse and index all medical PDFs
117
  from pathlib import Path; \
118
  from src.services.pdf_parser.service import make_pdf_parser_service; \
119
  from src.services.indexing.service import IndexingService; \
 
120
  from src.services.embeddings.service import make_embedding_service; \
121
  from src.services.opensearch.client import make_opensearch_client; \
122
  parser = make_pdf_parser_service(); \
123
- idx = IndexingService(make_embedding_service(), make_opensearch_client()); \
 
124
  docs = parser.parse_directory(Path('data/medical_pdfs')); \
125
- [idx.index_text(d.full_text, {'title': d.filename}) for d in docs if d.full_text]; \
126
  print(f'Indexed {len(docs)} documents')"
127
 
128
  # ---------------------------------------------------------------------------
 
117
  from pathlib import Path; \
118
  from src.services.pdf_parser.service import make_pdf_parser_service; \
119
  from src.services.indexing.service import IndexingService; \
120
+ from src.services.indexing.text_chunker import MedicalTextChunker; \
121
  from src.services.embeddings.service import make_embedding_service; \
122
  from src.services.opensearch.client import make_opensearch_client; \
123
  parser = make_pdf_parser_service(); \
124
+ chunker = MedicalTextChunker(); \
125
+ idx = IndexingService(chunker, make_embedding_service(), make_opensearch_client()); \
126
  docs = parser.parse_directory(Path('data/medical_pdfs')); \
127
+ [idx.index_text(d.full_text, title=d.filename, source_file=d.filename) for d in docs if d.full_text]; \
128
  print(f'Indexed {len(docs)} documents')"
129
 
130
  # ---------------------------------------------------------------------------
airflow/dags/ingest_pdfs.py CHANGED
@@ -9,9 +9,10 @@ from __future__ import annotations
9
 
10
  from datetime import datetime, timedelta
11
 
12
- from airflow import DAG
13
  from airflow.operators.python import PythonOperator
14
 
 
 
15
  default_args = {
16
  "owner": "mediguard",
17
  "retries": 2,
@@ -26,23 +27,25 @@ def _ingest_pdfs(**kwargs):
26
 
27
  from src.services.embeddings.service import make_embedding_service
28
  from src.services.indexing.service import IndexingService
 
29
  from src.services.opensearch.client import make_opensearch_client
30
  from src.services.pdf_parser.service import make_pdf_parser_service
31
  from src.settings import get_settings
32
 
33
  settings = get_settings()
34
- pdf_dir = Path(settings.medical_pdfs.directory)
35
 
36
  parser = make_pdf_parser_service()
37
  embedding_svc = make_embedding_service()
38
  os_client = make_opensearch_client()
39
- indexing_svc = IndexingService(embedding_svc, os_client)
 
40
 
41
  docs = parser.parse_directory(pdf_dir)
42
  indexed = 0
43
  for doc in docs:
44
  if doc.full_text and not doc.error:
45
- indexing_svc.index_text(doc.full_text, {"title": doc.filename})
46
  indexed += 1
47
 
48
  print(f"Ingested {indexed}/{len(docs)} documents")
 
9
 
10
  from datetime import datetime, timedelta
11
 
 
12
  from airflow.operators.python import PythonOperator
13
 
14
+ from airflow import DAG
15
+
16
  default_args = {
17
  "owner": "mediguard",
18
  "retries": 2,
 
27
 
28
  from src.services.embeddings.service import make_embedding_service
29
  from src.services.indexing.service import IndexingService
30
+ from src.services.indexing.text_chunker import MedicalTextChunker
31
  from src.services.opensearch.client import make_opensearch_client
32
  from src.services.pdf_parser.service import make_pdf_parser_service
33
  from src.settings import get_settings
34
 
35
  settings = get_settings()
36
+ pdf_dir = Path(settings.pdf.pdf_directory)
37
 
38
  parser = make_pdf_parser_service()
39
  embedding_svc = make_embedding_service()
40
  os_client = make_opensearch_client()
41
+ chunker = MedicalTextChunker(target_words=settings.chunking.chunk_size, overlap_words=settings.chunking.chunk_overlap, min_words=settings.chunking.min_chunk_size)
42
+ indexing_svc = IndexingService(chunker, embedding_svc, os_client)
43
 
44
  docs = parser.parse_directory(pdf_dir)
45
  indexed = 0
46
  for doc in docs:
47
  if doc.full_text and not doc.error:
48
+ indexing_svc.index_text(doc.full_text, title=doc.filename, source_file=doc.filename)
49
  indexed += 1
50
 
51
  print(f"Ingested {indexed}/{len(docs)} documents")
alembic/env.py CHANGED
@@ -1,25 +1,23 @@
1
- from logging.config import fileConfig
2
-
3
- from sqlalchemy import engine_from_config
4
- from sqlalchemy import pool, create_engine
5
-
6
- from alembic import context
7
 
8
  # ---------------------------------------------------------------------------
9
  # MediGuard AI — Alembic env.py
10
  # Pull DB URL from settings so we never hard-code credentials.
11
  # ---------------------------------------------------------------------------
12
  import sys
13
- import os
 
 
 
 
14
 
15
  # Make sure the project root is on sys.path
16
  sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
17
 
18
- from src.settings import get_settings # noqa: E402
19
- from src.database import Base # noqa: E402
20
-
21
  # Import all models so Alembic's autogenerate can see them
22
- import src.models.analysis # noqa: F401, E402
 
 
23
 
24
  # this is the Alembic Config object, which provides
25
  # access to the values within the .ini file in use.
 
1
+ import os
 
 
 
 
 
2
 
3
  # ---------------------------------------------------------------------------
4
  # MediGuard AI — Alembic env.py
5
  # Pull DB URL from settings so we never hard-code credentials.
6
  # ---------------------------------------------------------------------------
7
  import sys
8
+ from logging.config import fileConfig
9
+
10
+ from sqlalchemy import engine_from_config, pool
11
+
12
+ from alembic import context
13
 
14
  # Make sure the project root is on sys.path
15
  sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
16
 
 
 
 
17
  # Import all models so Alembic's autogenerate can see them
18
+ import src.models.analysis # noqa: F401
19
+ from src.database import Base
20
+ from src.settings import get_settings
21
 
22
  # this is the Alembic Config object, which provides
23
  # access to the values within the .ini file in use.
alembic/versions/001_initial.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """initial_tables
2
+
3
+ Revision ID: 001
4
+ Revises:
5
+ Create Date: 2026-02-24 20:58:00.000000
6
+
7
+ """
8
+ import sqlalchemy as sa
9
+
10
+ from alembic import op
11
+
12
+ # revision identifiers, used by Alembic.
13
+ revision = '001'
14
+ down_revision = None
15
+ branch_labels = None
16
+ depends_on = None
17
+
18
+
19
+ def upgrade() -> None:
20
+ op.create_table(
21
+ 'patient_analyses',
22
+ sa.Column('id', sa.String(length=36), nullable=False),
23
+ sa.Column('request_id', sa.String(length=64), nullable=False),
24
+ sa.Column('biomarkers', sa.JSON(), nullable=False),
25
+ sa.Column('patient_context', sa.JSON(), nullable=True),
26
+ sa.Column('predicted_disease', sa.String(length=128), nullable=False),
27
+ sa.Column('confidence', sa.Float(), nullable=False),
28
+ sa.Column('probabilities', sa.JSON(), nullable=True),
29
+ sa.Column('analysis_result', sa.JSON(), nullable=True),
30
+ sa.Column('safety_alerts', sa.JSON(), nullable=True),
31
+ sa.Column('sop_version', sa.String(length=64), nullable=True),
32
+ sa.Column('processing_time_ms', sa.Float(), nullable=False),
33
+ sa.Column('model_provider', sa.String(length=32), nullable=True),
34
+ sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
35
+ sa.PrimaryKeyConstraint('id')
36
+ )
37
+ op.create_index(op.f('ix_patient_analyses_request_id'), 'patient_analyses', ['request_id'], unique=True)
38
+
39
+ op.create_table(
40
+ 'medical_documents',
41
+ sa.Column('id', sa.String(length=36), nullable=False),
42
+ sa.Column('title', sa.String(length=512), nullable=False),
43
+ sa.Column('source', sa.String(length=512), nullable=False),
44
+ sa.Column('source_type', sa.String(length=32), nullable=False),
45
+ sa.Column('authors', sa.Text(), nullable=True),
46
+ sa.Column('abstract', sa.Text(), nullable=True),
47
+ sa.Column('content_hash', sa.String(length=64), nullable=True),
48
+ sa.Column('page_count', sa.Integer(), nullable=True),
49
+ sa.Column('chunk_count', sa.Integer(), nullable=True),
50
+ sa.Column('parse_status', sa.String(length=32), nullable=False),
51
+ sa.Column('metadata_json', sa.JSON(), nullable=True),
52
+ sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
53
+ sa.Column('indexed_at', sa.DateTime(timezone=True), nullable=True),
54
+ sa.PrimaryKeyConstraint('id'),
55
+ sa.UniqueConstraint('content_hash')
56
+ )
57
+ op.create_index(op.f('ix_medical_documents_title'), 'medical_documents', ['title'], unique=False)
58
+
59
+ op.create_table(
60
+ 'sop_versions',
61
+ sa.Column('id', sa.String(length=36), nullable=False),
62
+ sa.Column('version_tag', sa.String(length=64), nullable=False),
63
+ sa.Column('parameters', sa.JSON(), nullable=False),
64
+ sa.Column('evaluation_scores', sa.JSON(), nullable=True),
65
+ sa.Column('parent_version', sa.String(length=64), nullable=True),
66
+ sa.Column('is_active', sa.Boolean(), nullable=False),
67
+ sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
68
+ sa.PrimaryKeyConstraint('id')
69
+ )
70
+ op.create_index(op.f('ix_sop_versions_version_tag'), 'sop_versions', ['version_tag'], unique=True)
71
+
72
+
73
+ def downgrade() -> None:
74
+ op.drop_index(op.f('ix_sop_versions_version_tag'), table_name='sop_versions')
75
+ op.drop_table('sop_versions')
76
+
77
+ op.drop_index(op.f('ix_medical_documents_title'), table_name='medical_documents')
78
+ op.drop_table('medical_documents')
79
+
80
+ op.drop_index(op.f('ix_patient_analyses_request_id'), table_name='patient_analyses')
81
+ op.drop_table('patient_analyses')
api/app/main.py CHANGED
@@ -3,22 +3,19 @@ RagBot FastAPI Main Application
3
  Medical biomarker analysis API
4
  """
5
 
6
- import os
7
- import sys
8
  import logging
9
- from pathlib import Path
10
  from contextlib import asynccontextmanager
11
 
12
  from fastapi import FastAPI, Request, status
 
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import JSONResponse
15
- from fastapi.exceptions import RequestValidationError
16
 
17
  from app import __version__
18
- from app.routes import health, biomarkers, analyze
19
  from app.services.ragbot import get_ragbot_service
20
 
21
-
22
  # Configure logging
23
  logging.basicConfig(
24
  level=logging.INFO,
@@ -40,7 +37,7 @@ async def lifespan(app: FastAPI):
40
  logger.info("=" * 70)
41
  logger.info("Starting RagBot API Server")
42
  logger.info("=" * 70)
43
-
44
  # Startup: Initialize RagBot service
45
  try:
46
  ragbot_service = get_ragbot_service()
@@ -49,12 +46,12 @@ async def lifespan(app: FastAPI):
49
  except Exception as e:
50
  logger.error(f"Failed to initialize RagBot service: {e}")
51
  logger.warning("API will start but health checks will fail")
52
-
53
  logger.info("API server ready to accept requests")
54
  logger.info("=" * 70)
55
-
56
  yield # Server runs here
57
-
58
  # Shutdown
59
  logger.info("Shutting down RagBot API Server")
60
 
@@ -178,14 +175,14 @@ async def api_v1_info():
178
 
179
  if __name__ == "__main__":
180
  import uvicorn
181
-
182
  # Get configuration from environment
183
  host = os.getenv("API_HOST", "0.0.0.0")
184
  port = int(os.getenv("API_PORT", "8000"))
185
  reload = os.getenv("API_RELOAD", "false").lower() == "true"
186
-
187
  logger.info(f"Starting server on {host}:{port}")
188
-
189
  uvicorn.run(
190
  "app.main:app",
191
  host=host,
 
3
  Medical biomarker analysis API
4
  """
5
 
 
 
6
  import logging
7
+ import os
8
  from contextlib import asynccontextmanager
9
 
10
  from fastapi import FastAPI, Request, status
11
+ from fastapi.exceptions import RequestValidationError
12
  from fastapi.middleware.cors import CORSMiddleware
13
  from fastapi.responses import JSONResponse
 
14
 
15
  from app import __version__
16
+ from app.routes import analyze, biomarkers, health
17
  from app.services.ragbot import get_ragbot_service
18
 
 
19
  # Configure logging
20
  logging.basicConfig(
21
  level=logging.INFO,
 
37
  logger.info("=" * 70)
38
  logger.info("Starting RagBot API Server")
39
  logger.info("=" * 70)
40
+
41
  # Startup: Initialize RagBot service
42
  try:
43
  ragbot_service = get_ragbot_service()
 
46
  except Exception as e:
47
  logger.error(f"Failed to initialize RagBot service: {e}")
48
  logger.warning("API will start but health checks will fail")
49
+
50
  logger.info("API server ready to accept requests")
51
  logger.info("=" * 70)
52
+
53
  yield # Server runs here
54
+
55
  # Shutdown
56
  logger.info("Shutting down RagBot API Server")
57
 
 
175
 
176
  if __name__ == "__main__":
177
  import uvicorn
178
+
179
  # Get configuration from environment
180
  host = os.getenv("API_HOST", "0.0.0.0")
181
  port = int(os.getenv("API_PORT", "8000"))
182
  reload = os.getenv("API_RELOAD", "false").lower() == "true"
183
+
184
  logger.info(f"Starting server on {host}:{port}")
185
+
186
  uvicorn.run(
187
  "app.main:app",
188
  host=host,
api/app/routes/analyze.py CHANGED
@@ -4,19 +4,13 @@ Natural language and structured biomarker analysis
4
  """
5
 
6
  import os
7
- from datetime import datetime
8
  from fastapi import APIRouter, HTTPException, status
9
 
10
- from app.models.schemas import (
11
- NaturalAnalysisRequest,
12
- StructuredAnalysisRequest,
13
- AnalysisResponse,
14
- ErrorResponse
15
- )
16
  from app.services.extraction import extract_biomarkers, predict_disease_simple
17
  from app.services.ragbot import get_ragbot_service
18
 
19
-
20
  router = APIRouter(prefix="/api/v1", tags=["analysis"])
21
 
22
 
@@ -45,23 +39,23 @@ async def analyze_natural(request: NaturalAnalysisRequest):
45
 
46
  Returns full detailed analysis with all agent outputs, citations, recommendations.
47
  """
48
-
49
  # Get services
50
  ragbot_service = get_ragbot_service()
51
-
52
  if not ragbot_service.is_ready():
53
  raise HTTPException(
54
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
55
  detail="RagBot service not initialized. Please try again in a moment."
56
  )
57
-
58
  # Extract biomarkers from natural language
59
  ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
60
  biomarkers, extracted_context, error = extract_biomarkers(
61
  request.message,
62
  ollama_base_url=ollama_base_url
63
  )
64
-
65
  if error:
66
  raise HTTPException(
67
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -72,7 +66,7 @@ async def analyze_natural(request: NaturalAnalysisRequest):
72
  "suggestion": "Try: 'My glucose is 140 and HbA1c is 7.5'"
73
  }
74
  )
75
-
76
  if not biomarkers:
77
  raise HTTPException(
78
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -83,14 +77,14 @@ async def analyze_natural(request: NaturalAnalysisRequest):
83
  "suggestion": "Include specific biomarker values like 'glucose is 140'"
84
  }
85
  )
86
-
87
  # Merge extracted context with request context
88
  patient_context = request.patient_context.model_dump() if request.patient_context else {}
89
  patient_context.update(extracted_context)
90
-
91
  # Predict disease (simple rule-based for now)
92
  model_prediction = predict_disease_simple(biomarkers)
93
-
94
  try:
95
  # Run full analysis
96
  response = ragbot_service.analyze(
@@ -99,15 +93,15 @@ async def analyze_natural(request: NaturalAnalysisRequest):
99
  model_prediction=model_prediction,
100
  extracted_biomarkers=biomarkers # Keep original extraction
101
  )
102
-
103
  return response
104
-
105
  except Exception as e:
106
  raise HTTPException(
107
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
108
  detail={
109
  "error_code": "ANALYSIS_FAILED",
110
- "message": f"Analysis workflow failed: {str(e)}",
111
  "biomarkers_received": biomarkers
112
  }
113
  )
@@ -145,16 +139,16 @@ async def analyze_structured(request: StructuredAnalysisRequest):
145
  Use this endpoint when you already have structured biomarker data.
146
  Returns full detailed analysis with all agent outputs, citations, recommendations.
147
  """
148
-
149
  # Get services
150
  ragbot_service = get_ragbot_service()
151
-
152
  if not ragbot_service.is_ready():
153
  raise HTTPException(
154
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
155
  detail="RagBot service not initialized. Please try again in a moment."
156
  )
157
-
158
  # Validate biomarkers
159
  if not request.biomarkers:
160
  raise HTTPException(
@@ -165,13 +159,13 @@ async def analyze_structured(request: StructuredAnalysisRequest):
165
  "suggestion": "Provide at least one biomarker with a numeric value"
166
  }
167
  )
168
-
169
  # Patient context
170
  patient_context = request.patient_context.model_dump() if request.patient_context else {}
171
-
172
  # Predict disease
173
  model_prediction = predict_disease_simple(request.biomarkers)
174
-
175
  try:
176
  # Run full analysis
177
  response = ragbot_service.analyze(
@@ -180,15 +174,15 @@ async def analyze_structured(request: StructuredAnalysisRequest):
180
  model_prediction=model_prediction,
181
  extracted_biomarkers=None # No extraction for structured input
182
  )
183
-
184
  return response
185
-
186
  except Exception as e:
187
  raise HTTPException(
188
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
189
  detail={
190
  "error_code": "ANALYSIS_FAILED",
191
- "message": f"Analysis workflow failed: {str(e)}",
192
  "biomarkers_received": request.biomarkers
193
  }
194
  )
@@ -211,16 +205,16 @@ async def get_example():
211
 
212
  Same as CLI chatbot 'example' command.
213
  """
214
-
215
  # Get services
216
  ragbot_service = get_ragbot_service()
217
-
218
  if not ragbot_service.is_ready():
219
  raise HTTPException(
220
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
221
  detail="RagBot service not initialized. Please try again in a moment."
222
  )
223
-
224
  # Example biomarkers (Type 2 Diabetes patient)
225
  biomarkers = {
226
  "Glucose": 185.0,
@@ -235,14 +229,14 @@ async def get_example():
235
  "Systolic Blood Pressure": 142.0,
236
  "Diastolic Blood Pressure": 88.0
237
  }
238
-
239
  patient_context = {
240
  "age": 52,
241
  "gender": "male",
242
  "bmi": 31.2,
243
  "patient_id": "EXAMPLE-001"
244
  }
245
-
246
  model_prediction = {
247
  "disease": "Diabetes",
248
  "confidence": 0.87,
@@ -254,7 +248,7 @@ async def get_example():
254
  "Thrombocytopenia": 0.01
255
  }
256
  }
257
-
258
  try:
259
  # Run analysis
260
  response = ragbot_service.analyze(
@@ -263,14 +257,14 @@ async def get_example():
263
  model_prediction=model_prediction,
264
  extracted_biomarkers=None
265
  )
266
-
267
  return response
268
-
269
  except Exception as e:
270
  raise HTTPException(
271
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
272
  detail={
273
  "error_code": "EXAMPLE_FAILED",
274
- "message": f"Example analysis failed: {str(e)}"
275
  }
276
  )
 
4
  """
5
 
6
  import os
7
+
8
  from fastapi import APIRouter, HTTPException, status
9
 
10
+ from app.models.schemas import AnalysisResponse, NaturalAnalysisRequest, StructuredAnalysisRequest
 
 
 
 
 
11
  from app.services.extraction import extract_biomarkers, predict_disease_simple
12
  from app.services.ragbot import get_ragbot_service
13
 
 
14
  router = APIRouter(prefix="/api/v1", tags=["analysis"])
15
 
16
 
 
39
 
40
  Returns full detailed analysis with all agent outputs, citations, recommendations.
41
  """
42
+
43
  # Get services
44
  ragbot_service = get_ragbot_service()
45
+
46
  if not ragbot_service.is_ready():
47
  raise HTTPException(
48
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
49
  detail="RagBot service not initialized. Please try again in a moment."
50
  )
51
+
52
  # Extract biomarkers from natural language
53
  ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
54
  biomarkers, extracted_context, error = extract_biomarkers(
55
  request.message,
56
  ollama_base_url=ollama_base_url
57
  )
58
+
59
  if error:
60
  raise HTTPException(
61
  status_code=status.HTTP_400_BAD_REQUEST,
 
66
  "suggestion": "Try: 'My glucose is 140 and HbA1c is 7.5'"
67
  }
68
  )
69
+
70
  if not biomarkers:
71
  raise HTTPException(
72
  status_code=status.HTTP_400_BAD_REQUEST,
 
77
  "suggestion": "Include specific biomarker values like 'glucose is 140'"
78
  }
79
  )
80
+
81
  # Merge extracted context with request context
82
  patient_context = request.patient_context.model_dump() if request.patient_context else {}
83
  patient_context.update(extracted_context)
84
+
85
  # Predict disease (simple rule-based for now)
86
  model_prediction = predict_disease_simple(biomarkers)
87
+
88
  try:
89
  # Run full analysis
90
  response = ragbot_service.analyze(
 
93
  model_prediction=model_prediction,
94
  extracted_biomarkers=biomarkers # Keep original extraction
95
  )
96
+
97
  return response
98
+
99
  except Exception as e:
100
  raise HTTPException(
101
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
102
  detail={
103
  "error_code": "ANALYSIS_FAILED",
104
+ "message": f"Analysis workflow failed: {e!s}",
105
  "biomarkers_received": biomarkers
106
  }
107
  )
 
139
  Use this endpoint when you already have structured biomarker data.
140
  Returns full detailed analysis with all agent outputs, citations, recommendations.
141
  """
142
+
143
  # Get services
144
  ragbot_service = get_ragbot_service()
145
+
146
  if not ragbot_service.is_ready():
147
  raise HTTPException(
148
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
149
  detail="RagBot service not initialized. Please try again in a moment."
150
  )
151
+
152
  # Validate biomarkers
153
  if not request.biomarkers:
154
  raise HTTPException(
 
159
  "suggestion": "Provide at least one biomarker with a numeric value"
160
  }
161
  )
162
+
163
  # Patient context
164
  patient_context = request.patient_context.model_dump() if request.patient_context else {}
165
+
166
  # Predict disease
167
  model_prediction = predict_disease_simple(request.biomarkers)
168
+
169
  try:
170
  # Run full analysis
171
  response = ragbot_service.analyze(
 
174
  model_prediction=model_prediction,
175
  extracted_biomarkers=None # No extraction for structured input
176
  )
177
+
178
  return response
179
+
180
  except Exception as e:
181
  raise HTTPException(
182
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
183
  detail={
184
  "error_code": "ANALYSIS_FAILED",
185
+ "message": f"Analysis workflow failed: {e!s}",
186
  "biomarkers_received": request.biomarkers
187
  }
188
  )
 
205
 
206
  Same as CLI chatbot 'example' command.
207
  """
208
+
209
  # Get services
210
  ragbot_service = get_ragbot_service()
211
+
212
  if not ragbot_service.is_ready():
213
  raise HTTPException(
214
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
215
  detail="RagBot service not initialized. Please try again in a moment."
216
  )
217
+
218
  # Example biomarkers (Type 2 Diabetes patient)
219
  biomarkers = {
220
  "Glucose": 185.0,
 
229
  "Systolic Blood Pressure": 142.0,
230
  "Diastolic Blood Pressure": 88.0
231
  }
232
+
233
  patient_context = {
234
  "age": 52,
235
  "gender": "male",
236
  "bmi": 31.2,
237
  "patient_id": "EXAMPLE-001"
238
  }
239
+
240
  model_prediction = {
241
  "disease": "Diabetes",
242
  "confidence": 0.87,
 
248
  "Thrombocytopenia": 0.01
249
  }
250
  }
251
+
252
  try:
253
  # Run analysis
254
  response = ragbot_service.analyze(
 
257
  model_prediction=model_prediction,
258
  extracted_biomarkers=None
259
  )
260
+
261
  return response
262
+
263
  except Exception as e:
264
  raise HTTPException(
265
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
266
  detail={
267
  "error_code": "EXAMPLE_FAILED",
268
+ "message": f"Example analysis failed: {e!s}"
269
  }
270
  )
api/app/routes/biomarkers.py CHANGED
@@ -3,13 +3,12 @@ Biomarkers List Endpoint
3
  """
4
 
5
  import json
6
- import sys
7
- from pathlib import Path
8
  from datetime import datetime
9
- from fastapi import APIRouter, HTTPException
10
 
11
- from app.models.schemas import BiomarkersListResponse, BiomarkerInfo, BiomarkerReferenceRange
12
 
 
13
 
14
  router = APIRouter(prefix="/api/v1", tags=["biomarkers"])
15
 
@@ -30,22 +29,22 @@ async def list_biomarkers():
30
  - Understanding what biomarkers can be analyzed
31
  - Getting reference ranges for display
32
  """
33
-
34
  try:
35
  # Load biomarker references
36
  config_path = Path(__file__).parent.parent.parent.parent / "config" / "biomarker_references.json"
37
-
38
- with open(config_path, 'r') as f:
39
  config_data = json.load(f)
40
-
41
  biomarkers_data = config_data.get("biomarkers", {})
42
-
43
  biomarkers_list = []
44
-
45
  for name, info in biomarkers_data.items():
46
  # Parse reference range
47
  normal_range_data = info.get("normal_range", {})
48
-
49
  if "male" in normal_range_data or "female" in normal_range_data:
50
  # Gender-specific ranges
51
  reference_range = BiomarkerReferenceRange(
@@ -62,7 +61,7 @@ async def list_biomarkers():
62
  male=None,
63
  female=None
64
  )
65
-
66
  biomarker_info = BiomarkerInfo(
67
  name=name,
68
  unit=info.get("unit", ""),
@@ -73,23 +72,23 @@ async def list_biomarkers():
73
  description=info.get("description", ""),
74
  clinical_significance=info.get("clinical_significance", {})
75
  )
76
-
77
  biomarkers_list.append(biomarker_info)
78
-
79
  return BiomarkersListResponse(
80
  biomarkers=biomarkers_list,
81
  total_count=len(biomarkers_list),
82
  timestamp=datetime.now().isoformat()
83
  )
84
-
85
  except FileNotFoundError:
86
  raise HTTPException(
87
  status_code=500,
88
  detail="Biomarker configuration file not found"
89
  )
90
-
91
  except Exception as e:
92
  raise HTTPException(
93
  status_code=500,
94
- detail=f"Failed to load biomarkers: {str(e)}"
95
  )
 
3
  """
4
 
5
  import json
 
 
6
  from datetime import datetime
7
+ from pathlib import Path
8
 
9
+ from fastapi import APIRouter, HTTPException
10
 
11
+ from app.models.schemas import BiomarkerInfo, BiomarkerReferenceRange, BiomarkersListResponse
12
 
13
  router = APIRouter(prefix="/api/v1", tags=["biomarkers"])
14
 
 
29
  - Understanding what biomarkers can be analyzed
30
  - Getting reference ranges for display
31
  """
32
+
33
  try:
34
  # Load biomarker references
35
  config_path = Path(__file__).parent.parent.parent.parent / "config" / "biomarker_references.json"
36
+
37
+ with open(config_path) as f:
38
  config_data = json.load(f)
39
+
40
  biomarkers_data = config_data.get("biomarkers", {})
41
+
42
  biomarkers_list = []
43
+
44
  for name, info in biomarkers_data.items():
45
  # Parse reference range
46
  normal_range_data = info.get("normal_range", {})
47
+
48
  if "male" in normal_range_data or "female" in normal_range_data:
49
  # Gender-specific ranges
50
  reference_range = BiomarkerReferenceRange(
 
61
  male=None,
62
  female=None
63
  )
64
+
65
  biomarker_info = BiomarkerInfo(
66
  name=name,
67
  unit=info.get("unit", ""),
 
72
  description=info.get("description", ""),
73
  clinical_significance=info.get("clinical_significance", {})
74
  )
75
+
76
  biomarkers_list.append(biomarker_info)
77
+
78
  return BiomarkersListResponse(
79
  biomarkers=biomarkers_list,
80
  total_count=len(biomarkers_list),
81
  timestamp=datetime.now().isoformat()
82
  )
83
+
84
  except FileNotFoundError:
85
  raise HTTPException(
86
  status_code=500,
87
  detail="Biomarker configuration file not found"
88
  )
89
+
90
  except Exception as e:
91
  raise HTTPException(
92
  status_code=500,
93
+ detail=f"Failed to load biomarkers: {e!s}"
94
  )
api/app/routes/health.py CHANGED
@@ -2,16 +2,13 @@
2
  Health Check Endpoint
3
  """
4
 
5
- import os
6
- import sys
7
- from pathlib import Path
8
  from datetime import datetime
9
- from fastapi import APIRouter, HTTPException
10
 
 
 
 
11
  from app.models.schemas import HealthResponse
12
  from app.services.ragbot import get_ragbot_service
13
- from app import __version__
14
-
15
 
16
  router = APIRouter(prefix="/api/v1", tags=["health"])
17
 
@@ -30,16 +27,16 @@ async def health_check():
30
  Returns health status with component details.
31
  """
32
  ragbot_service = get_ragbot_service()
33
-
34
  # Check LLM API connection
35
  llm_status = "disconnected"
36
  available_models = []
37
-
38
  try:
39
- from src.llm_config import get_chat_model, DEFAULT_LLM_PROVIDER
40
-
41
  test_llm = get_chat_model(temperature=0.0)
42
-
43
  # Try a simple test
44
  response = test_llm.invoke("Say OK")
45
  if response:
@@ -50,13 +47,13 @@ async def health_check():
50
  available_models = ["gemini-2.0-flash (Google)"]
51
  else:
52
  available_models = ["llama3.1:8b (Ollama)"]
53
-
54
  except Exception as e:
55
  llm_status = f"error: {str(e)[:100]}"
56
-
57
  # Check vector store
58
  vector_store_loaded = ragbot_service.is_ready()
59
-
60
  # Determine overall status
61
  if llm_status == "connected" and vector_store_loaded:
62
  overall_status = "healthy"
@@ -64,7 +61,7 @@ async def health_check():
64
  overall_status = "degraded"
65
  else:
66
  overall_status = "unhealthy"
67
-
68
  return HealthResponse(
69
  status=overall_status,
70
  timestamp=datetime.now().isoformat(),
 
2
  Health Check Endpoint
3
  """
4
 
 
 
 
5
  from datetime import datetime
 
6
 
7
+ from fastapi import APIRouter
8
+
9
+ from app import __version__
10
  from app.models.schemas import HealthResponse
11
  from app.services.ragbot import get_ragbot_service
 
 
12
 
13
  router = APIRouter(prefix="/api/v1", tags=["health"])
14
 
 
27
  Returns health status with component details.
28
  """
29
  ragbot_service = get_ragbot_service()
30
+
31
  # Check LLM API connection
32
  llm_status = "disconnected"
33
  available_models = []
34
+
35
  try:
36
+ from src.llm_config import DEFAULT_LLM_PROVIDER, get_chat_model
37
+
38
  test_llm = get_chat_model(temperature=0.0)
39
+
40
  # Try a simple test
41
  response = test_llm.invoke("Say OK")
42
  if response:
 
47
  available_models = ["gemini-2.0-flash (Google)"]
48
  else:
49
  available_models = ["llama3.1:8b (Ollama)"]
50
+
51
  except Exception as e:
52
  llm_status = f"error: {str(e)[:100]}"
53
+
54
  # Check vector store
55
  vector_store_loaded = ragbot_service.is_ready()
56
+
57
  # Determine overall status
58
  if llm_status == "connected" and vector_store_loaded:
59
  overall_status = "healthy"
 
61
  overall_status = "degraded"
62
  else:
63
  overall_status = "unhealthy"
64
+
65
  return HealthResponse(
66
  status=overall_status,
67
  timestamp=datetime.now().isoformat(),
api/app/services/extraction.py CHANGED
@@ -6,7 +6,7 @@ Extracts biomarker values from natural language text using LLM
6
  import json
7
  import sys
8
  from pathlib import Path
9
- from typing import Dict, Any, Tuple
10
 
11
  # Ensure project root is in path for src imports
12
  _project_root = str(Path(__file__).parent.parent.parent.parent)
@@ -14,10 +14,10 @@ if _project_root not in sys.path:
14
  sys.path.insert(0, _project_root)
15
 
16
  from langchain_core.prompts import ChatPromptTemplate
 
17
  from src.biomarker_normalization import normalize_biomarker_name
18
  from src.llm_config import get_chat_model
19
 
20
-
21
  # ============================================================================
22
  # EXTRACTION PROMPT
23
  # ============================================================================
@@ -54,7 +54,7 @@ If you cannot find any biomarkers, return {{"biomarkers": {{}}, "patient_context
54
  # EXTRACTION HELPERS
55
  # ============================================================================
56
 
57
- def _parse_llm_json(content: str) -> Dict[str, Any]:
58
  """Parse JSON payload from LLM output with fallback recovery."""
59
  text = content.strip()
60
 
@@ -78,9 +78,9 @@ def _parse_llm_json(content: str) -> Dict[str, Any]:
78
  # ============================================================================
79
 
80
  def extract_biomarkers(
81
- user_message: str,
82
  ollama_base_url: str = None # Kept for backward compatibility, ignored
83
- ) -> Tuple[Dict[str, float], Dict[str, Any], str]:
84
  """
85
  Extract biomarker values from natural language using LLM.
86
 
@@ -102,18 +102,18 @@ def extract_biomarkers(
102
  try:
103
  # Initialize LLM (uses Groq/Gemini by default - FREE)
104
  llm = get_chat_model(temperature=0.0)
105
-
106
  prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT)
107
  chain = prompt | llm
108
-
109
  # Invoke LLM
110
  response = chain.invoke({"user_message": user_message})
111
  content = response.content.strip()
112
-
113
  extracted = _parse_llm_json(content)
114
  biomarkers = extracted.get("biomarkers", {})
115
  patient_context = extracted.get("patient_context", {})
116
-
117
  # Normalize biomarker names and convert to float
118
  normalized = {}
119
  for key, value in biomarkers.items():
@@ -123,27 +123,27 @@ def extract_biomarkers(
123
  except (ValueError, TypeError):
124
  # Skip invalid values
125
  continue
126
-
127
  # Clean up patient context (remove null values)
128
  patient_context = {k: v for k, v in patient_context.items() if v is not None}
129
-
130
  if not normalized:
131
  return {}, patient_context, "No biomarkers found in the input"
132
-
133
  return normalized, patient_context, ""
134
-
135
  except json.JSONDecodeError as e:
136
- return {}, {}, f"Failed to parse LLM response as JSON: {str(e)}"
137
-
138
  except Exception as e:
139
- return {}, {}, f"Extraction failed: {str(e)}"
140
 
141
 
142
  # ============================================================================
143
  # SIMPLE DISEASE PREDICTION (Fallback)
144
  # ============================================================================
145
 
146
- def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
147
  """
148
  Simple rule-based disease prediction based on key biomarkers.
149
  Used as a fallback when no ML model is available.
@@ -161,15 +161,15 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
161
  "Thrombocytopenia": 0.0,
162
  "Thalassemia": 0.0
163
  }
164
-
165
  # Helper: check both abbreviated and normalized biomarker names
166
  # Returns None when biomarker is not present (avoids false triggers)
167
  def _get(name, *alt_names):
168
- val = biomarkers.get(name, None)
169
  if val is not None:
170
  return val
171
  for alt in alt_names:
172
- val = biomarkers.get(alt, None)
173
  if val is not None:
174
  return val
175
  return None
@@ -183,7 +183,7 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
183
  scores["Diabetes"] += 0.2
184
  if hba1c is not None and hba1c >= 6.5:
185
  scores["Diabetes"] += 0.5
186
-
187
  # Anemia indicators
188
  hemoglobin = _get("Hemoglobin")
189
  mcv = _get("Mean Corpuscular Volume", "MCV")
@@ -193,7 +193,7 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
193
  scores["Anemia"] += 0.2
194
  if mcv is not None and mcv < 80:
195
  scores["Anemia"] += 0.2
196
-
197
  # Heart disease indicators
198
  cholesterol = _get("Cholesterol")
199
  troponin = _get("Troponin")
@@ -204,32 +204,32 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
204
  scores["Heart Disease"] += 0.6
205
  if ldl is not None and ldl > 190:
206
  scores["Heart Disease"] += 0.2
207
-
208
  # Thrombocytopenia indicators
209
  platelets = _get("Platelets")
210
  if platelets is not None and platelets < 150000:
211
  scores["Thrombocytopenia"] += 0.6
212
  if platelets is not None and platelets < 50000:
213
  scores["Thrombocytopenia"] += 0.3
214
-
215
  # Thalassemia indicators (simplified)
216
  if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
217
  scores["Thalassemia"] += 0.4
218
-
219
  # Find top prediction
220
  top_disease = max(scores, key=scores.get)
221
  confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation
222
 
223
  if confidence == 0.0:
224
  top_disease = "Undetermined"
225
-
226
  # Normalize probabilities to sum to 1.0
227
  total = sum(scores.values())
228
  if total > 0:
229
  probabilities = {k: v / total for k, v in scores.items()}
230
  else:
231
  probabilities = {k: 1.0 / len(scores) for k in scores}
232
-
233
  return {
234
  "disease": top_disease,
235
  "confidence": confidence,
 
6
  import json
7
  import sys
8
  from pathlib import Path
9
+ from typing import Any
10
 
11
  # Ensure project root is in path for src imports
12
  _project_root = str(Path(__file__).parent.parent.parent.parent)
 
14
  sys.path.insert(0, _project_root)
15
 
16
  from langchain_core.prompts import ChatPromptTemplate
17
+
18
  from src.biomarker_normalization import normalize_biomarker_name
19
  from src.llm_config import get_chat_model
20
 
 
21
  # ============================================================================
22
  # EXTRACTION PROMPT
23
  # ============================================================================
 
54
  # EXTRACTION HELPERS
55
  # ============================================================================
56
 
57
+ def _parse_llm_json(content: str) -> dict[str, Any]:
58
  """Parse JSON payload from LLM output with fallback recovery."""
59
  text = content.strip()
60
 
 
78
  # ============================================================================
79
 
80
  def extract_biomarkers(
81
+ user_message: str,
82
  ollama_base_url: str = None # Kept for backward compatibility, ignored
83
+ ) -> tuple[dict[str, float], dict[str, Any], str]:
84
  """
85
  Extract biomarker values from natural language using LLM.
86
 
 
102
  try:
103
  # Initialize LLM (uses Groq/Gemini by default - FREE)
104
  llm = get_chat_model(temperature=0.0)
105
+
106
  prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT)
107
  chain = prompt | llm
108
+
109
  # Invoke LLM
110
  response = chain.invoke({"user_message": user_message})
111
  content = response.content.strip()
112
+
113
  extracted = _parse_llm_json(content)
114
  biomarkers = extracted.get("biomarkers", {})
115
  patient_context = extracted.get("patient_context", {})
116
+
117
  # Normalize biomarker names and convert to float
118
  normalized = {}
119
  for key, value in biomarkers.items():
 
123
  except (ValueError, TypeError):
124
  # Skip invalid values
125
  continue
126
+
127
  # Clean up patient context (remove null values)
128
  patient_context = {k: v for k, v in patient_context.items() if v is not None}
129
+
130
  if not normalized:
131
  return {}, patient_context, "No biomarkers found in the input"
132
+
133
  return normalized, patient_context, ""
134
+
135
  except json.JSONDecodeError as e:
136
+ return {}, {}, f"Failed to parse LLM response as JSON: {e!s}"
137
+
138
  except Exception as e:
139
+ return {}, {}, f"Extraction failed: {e!s}"
140
 
141
 
142
  # ============================================================================
143
  # SIMPLE DISEASE PREDICTION (Fallback)
144
  # ============================================================================
145
 
146
+ def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
147
  """
148
  Simple rule-based disease prediction based on key biomarkers.
149
  Used as a fallback when no ML model is available.
 
161
  "Thrombocytopenia": 0.0,
162
  "Thalassemia": 0.0
163
  }
164
+
165
  # Helper: check both abbreviated and normalized biomarker names
166
  # Returns None when biomarker is not present (avoids false triggers)
167
  def _get(name, *alt_names):
168
+ val = biomarkers.get(name)
169
  if val is not None:
170
  return val
171
  for alt in alt_names:
172
+ val = biomarkers.get(alt)
173
  if val is not None:
174
  return val
175
  return None
 
183
  scores["Diabetes"] += 0.2
184
  if hba1c is not None and hba1c >= 6.5:
185
  scores["Diabetes"] += 0.5
186
+
187
  # Anemia indicators
188
  hemoglobin = _get("Hemoglobin")
189
  mcv = _get("Mean Corpuscular Volume", "MCV")
 
193
  scores["Anemia"] += 0.2
194
  if mcv is not None and mcv < 80:
195
  scores["Anemia"] += 0.2
196
+
197
  # Heart disease indicators
198
  cholesterol = _get("Cholesterol")
199
  troponin = _get("Troponin")
 
204
  scores["Heart Disease"] += 0.6
205
  if ldl is not None and ldl > 190:
206
  scores["Heart Disease"] += 0.2
207
+
208
  # Thrombocytopenia indicators
209
  platelets = _get("Platelets")
210
  if platelets is not None and platelets < 150000:
211
  scores["Thrombocytopenia"] += 0.6
212
  if platelets is not None and platelets < 50000:
213
  scores["Thrombocytopenia"] += 0.3
214
+
215
  # Thalassemia indicators (simplified)
216
  if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
217
  scores["Thalassemia"] += 0.4
218
+
219
  # Find top prediction
220
  top_disease = max(scores, key=scores.get)
221
  confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation
222
 
223
  if confidence == 0.0:
224
  top_disease = "Undetermined"
225
+
226
  # Normalize probabilities to sum to 1.0
227
  total = sum(scores.values())
228
  if total > 0:
229
  probabilities = {k: v / total for k, v in scores.items()}
230
  else:
231
  probabilities = {k: 1.0 / len(scores) for k in scores}
232
+
233
  return {
234
  "disease": top_disease,
235
  "confidence": confidence,
api/app/services/ragbot.py CHANGED
@@ -6,22 +6,29 @@ Wraps the RagBot workflow and formats comprehensive responses
6
  import sys
7
  import time
8
  import uuid
9
- from pathlib import Path
10
- from typing import Dict, Any
11
  from datetime import datetime
 
 
12
 
13
  # Ensure project root is in path for src imports
14
  _project_root = str(Path(__file__).parent.parent.parent.parent)
15
  if _project_root not in sys.path:
16
  sys.path.insert(0, _project_root)
17
 
18
- from src.workflow import create_guild
19
- from src.state import PatientInput
20
  from app.models.schemas import (
21
- AnalysisResponse, Analysis, Prediction, BiomarkerFlag,
22
- SafetyAlert, KeyDriver, DiseaseExplanation, Recommendations,
23
- ConfidenceAssessment, AgentOutput
 
 
 
 
 
 
 
24
  )
 
 
25
 
26
 
27
  class RagBotService:
@@ -29,65 +36,65 @@ class RagBotService:
29
  Service class to manage RagBot workflow lifecycle.
30
  Initializes once, then handles multiple analysis requests.
31
  """
32
-
33
  def __init__(self):
34
  """Initialize the workflow (loads vector store, models, etc.)"""
35
  self.guild = None
36
  self.initialized = False
37
  self.init_time = None
38
-
39
  def initialize(self):
40
  """Initialize the Clinical Insight Guild (expensive operation)"""
41
  if self.initialized:
42
  return
43
-
44
  print("INFO: Initializing RagBot workflow...")
45
  start_time = time.time()
46
-
47
  import os
48
-
49
  try:
50
  # Set working directory via environment so vector store paths resolve
51
  # without a process-global os.chdir() (which is thread-unsafe).
52
  ragbot_root = Path(__file__).parent.parent.parent.parent
53
  os.environ["RAGBOT_ROOT"] = str(ragbot_root)
54
  print(f"INFO: Project root: {ragbot_root}")
55
-
56
  # Temporarily chdir only during initialization (single-threaded at startup)
57
  original_dir = os.getcwd()
58
  os.chdir(ragbot_root)
59
-
60
  self.guild = create_guild()
61
  self.initialized = True
62
  self.init_time = datetime.now()
63
-
64
  elapsed = (time.time() - start_time) * 1000
65
  print(f"OK: RagBot initialized successfully ({elapsed:.0f}ms)")
66
-
67
  except Exception as e:
68
  print(f"ERROR: Failed to initialize RagBot: {e}")
69
  raise
70
-
71
  finally:
72
  # Restore original directory
73
  os.chdir(original_dir)
74
-
75
  def get_uptime_seconds(self) -> float:
76
  """Get API uptime in seconds"""
77
  if not self.init_time:
78
  return 0.0
79
  return (datetime.now() - self.init_time).total_seconds()
80
-
81
  def is_ready(self) -> bool:
82
  """Check if service is ready to handle requests"""
83
  return self.initialized and self.guild is not None
84
-
85
  def analyze(
86
  self,
87
- biomarkers: Dict[str, float],
88
- patient_context: Dict[str, Any],
89
- model_prediction: Dict[str, Any],
90
- extracted_biomarkers: Dict[str, float] = None
91
  ) -> AnalysisResponse:
92
  """
93
  Run complete analysis workflow and format full detailed response.
@@ -103,10 +110,10 @@ class RagBotService:
103
  """
104
  if not self.is_ready():
105
  raise RuntimeError("RagBot service not initialized. Call initialize() first.")
106
-
107
  request_id = f"req_{uuid.uuid4().hex[:12]}"
108
  start_time = time.time()
109
-
110
  try:
111
  # Create PatientInput
112
  patient_input = PatientInput(
@@ -114,13 +121,13 @@ class RagBotService:
114
  model_prediction=model_prediction,
115
  patient_context=patient_context
116
  )
117
-
118
  # Run workflow
119
  workflow_result = self.guild.run(patient_input)
120
-
121
  # Calculate processing time
122
  processing_time_ms = (time.time() - start_time) * 1000
123
-
124
  # Format response
125
  response = self._format_response(
126
  request_id=request_id,
@@ -131,21 +138,21 @@ class RagBotService:
131
  model_prediction=model_prediction,
132
  processing_time_ms=processing_time_ms
133
  )
134
-
135
  return response
136
-
137
  except Exception as e:
138
  # Re-raise with context
139
- raise RuntimeError(f"Analysis failed during workflow execution: {str(e)}") from e
140
-
141
  def _format_response(
142
  self,
143
  request_id: str,
144
- workflow_result: Dict[str, Any],
145
- input_biomarkers: Dict[str, float],
146
- extracted_biomarkers: Dict[str, float],
147
- patient_context: Dict[str, Any],
148
- model_prediction: Dict[str, Any],
149
  processing_time_ms: float
150
  ) -> AnalysisResponse:
151
  """
@@ -159,17 +166,17 @@ class RagBotService:
159
  - safety_alerts: list of SafetyAlert objects
160
  - sop_version, processing_timestamp, etc.
161
  """
162
-
163
  # The synthesizer output is nested inside final_response
164
  final_response = workflow_result.get("final_response", {}) or {}
165
-
166
  # Extract main prediction
167
  prediction = Prediction(
168
  disease=model_prediction["disease"],
169
  confidence=model_prediction["confidence"],
170
  probabilities=model_prediction.get("probabilities", {})
171
  )
172
-
173
  # Biomarker flags: prefer state-level data (BiomarkerFlag objects from validator),
174
  # fall back to synthesizer output
175
  state_flags = workflow_result.get("biomarker_flags", [])
@@ -188,7 +195,7 @@ class RagBotService:
188
  BiomarkerFlag(**flag) if isinstance(flag, dict) else BiomarkerFlag(**flag.model_dump())
189
  for flag in biomarker_flags_source
190
  ]
191
-
192
  # Safety alerts: prefer state-level data, fall back to synthesizer
193
  state_alerts = workflow_result.get("safety_alerts", [])
194
  if state_alerts:
@@ -206,7 +213,7 @@ class RagBotService:
206
  SafetyAlert(**alert) if isinstance(alert, dict) else SafetyAlert(**alert.model_dump())
207
  for alert in safety_alerts_source
208
  ]
209
-
210
  # Extract key drivers from synthesizer output
211
  key_drivers_data = final_response.get("key_drivers", [])
212
  if not key_drivers_data:
@@ -215,7 +222,7 @@ class RagBotService:
215
  for driver in key_drivers_data:
216
  if isinstance(driver, dict):
217
  key_drivers.append(KeyDriver(**driver))
218
-
219
  # Disease explanation from synthesizer
220
  disease_exp_data = final_response.get("disease_explanation", {})
221
  if not disease_exp_data:
@@ -225,7 +232,7 @@ class RagBotService:
225
  citations=disease_exp_data.get("citations", []),
226
  retrieved_chunks=disease_exp_data.get("retrieved_chunks")
227
  )
228
-
229
  # Recommendations from synthesizer
230
  recs_data = final_response.get("recommendations", {})
231
  if not recs_data:
@@ -238,7 +245,7 @@ class RagBotService:
238
  monitoring=recs_data.get("monitoring", []),
239
  follow_up=recs_data.get("follow_up")
240
  )
241
-
242
  # Confidence assessment from synthesizer
243
  conf_data = final_response.get("confidence_assessment", {})
244
  if not conf_data:
@@ -249,12 +256,12 @@ class RagBotService:
249
  limitations=conf_data.get("limitations", []),
250
  reasoning=conf_data.get("reasoning")
251
  )
252
-
253
  # Alternative diagnoses
254
  alternative_diagnoses = final_response.get("alternative_diagnoses")
255
  if alternative_diagnoses is None:
256
  alternative_diagnoses = final_response.get("analysis", {}).get("alternative_diagnoses")
257
-
258
  # Assemble complete analysis
259
  analysis = Analysis(
260
  biomarker_flags=biomarker_flags,
@@ -265,7 +272,7 @@ class RagBotService:
265
  confidence_assessment=confidence_assessment,
266
  alternative_diagnoses=alternative_diagnoses
267
  )
268
-
269
  # Agent outputs from state (these are src.state.AgentOutput objects)
270
  agent_outputs_data = workflow_result.get("agent_outputs", [])
271
  agent_outputs = []
@@ -274,7 +281,7 @@ class RagBotService:
274
  agent_outputs.append(AgentOutput(**agent_out.model_dump()))
275
  elif isinstance(agent_out, dict):
276
  agent_outputs.append(AgentOutput(**agent_out))
277
-
278
  # Workflow metadata
279
  workflow_metadata = {
280
  "sop_version": workflow_result.get("sop_version"),
@@ -282,12 +289,12 @@ class RagBotService:
282
  "agents_executed": len(agent_outputs),
283
  "workflow_success": True
284
  }
285
-
286
  # Conversational summary (if available)
287
  conversational_summary = final_response.get("conversational_summary")
288
  if not conversational_summary:
289
  conversational_summary = final_response.get("patient_summary", {}).get("narrative")
290
-
291
  # Generate conversational summary if not present
292
  if not conversational_summary:
293
  conversational_summary = self._generate_conversational_summary(
@@ -296,7 +303,7 @@ class RagBotService:
296
  key_drivers=key_drivers,
297
  recommendations=recommendations
298
  )
299
-
300
  # Assemble final response
301
  response = AnalysisResponse(
302
  status="success",
@@ -313,9 +320,9 @@ class RagBotService:
313
  processing_time_ms=processing_time_ms,
314
  sop_version=workflow_result.get("sop_version", "Baseline")
315
  )
316
-
317
  return response
318
-
319
  def _generate_conversational_summary(
320
  self,
321
  prediction: Prediction,
@@ -324,37 +331,37 @@ class RagBotService:
324
  recommendations: Recommendations
325
  ) -> str:
326
  """Generate a simple conversational summary"""
327
-
328
  summary_parts = []
329
  summary_parts.append("Hi there!\n")
330
  summary_parts.append("Based on your biomarkers, I analyzed your results.\n")
331
-
332
  # Prediction
333
  summary_parts.append(f"\nPrimary Finding: {prediction.disease}")
334
  summary_parts.append(f" Confidence: {prediction.confidence:.0%}\n")
335
-
336
  # Safety alerts
337
  if safety_alerts:
338
  summary_parts.append("\nIMPORTANT SAFETY ALERTS:")
339
  for alert in safety_alerts[:3]: # Top 3
340
  summary_parts.append(f" - {alert.biomarker}: {alert.message}")
341
  summary_parts.append(f" Action: {alert.action}")
342
-
343
  # Key drivers
344
  if key_drivers:
345
  summary_parts.append("\nWhy this prediction?")
346
  for driver in key_drivers[:3]: # Top 3
347
  summary_parts.append(f" - {driver.biomarker} ({driver.value}): {driver.explanation[:100]}...")
348
-
349
  # Recommendations
350
  if recommendations.immediate_actions:
351
  summary_parts.append("\nWhat You Should Do:")
352
  for i, action in enumerate(recommendations.immediate_actions[:3], 1):
353
  summary_parts.append(f" {i}. {action}")
354
-
355
  summary_parts.append("\nImportant: This is an AI-assisted analysis, NOT medical advice.")
356
  summary_parts.append(" Please consult a healthcare professional for proper diagnosis and treatment.")
357
-
358
  return "\n".join(summary_parts)
359
 
360
 
 
6
  import sys
7
  import time
8
  import uuid
 
 
9
  from datetime import datetime
10
+ from pathlib import Path
11
+ from typing import Any
12
 
13
  # Ensure project root is in path for src imports
14
  _project_root = str(Path(__file__).parent.parent.parent.parent)
15
  if _project_root not in sys.path:
16
  sys.path.insert(0, _project_root)
17
 
 
 
18
  from app.models.schemas import (
19
+ AgentOutput,
20
+ Analysis,
21
+ AnalysisResponse,
22
+ BiomarkerFlag,
23
+ ConfidenceAssessment,
24
+ DiseaseExplanation,
25
+ KeyDriver,
26
+ Prediction,
27
+ Recommendations,
28
+ SafetyAlert,
29
  )
30
+ from src.state import PatientInput
31
+ from src.workflow import create_guild
32
 
33
 
34
  class RagBotService:
 
36
  Service class to manage RagBot workflow lifecycle.
37
  Initializes once, then handles multiple analysis requests.
38
  """
39
+
40
  def __init__(self):
41
  """Initialize the workflow (loads vector store, models, etc.)"""
42
  self.guild = None
43
  self.initialized = False
44
  self.init_time = None
45
+
46
  def initialize(self):
47
  """Initialize the Clinical Insight Guild (expensive operation)"""
48
  if self.initialized:
49
  return
50
+
51
  print("INFO: Initializing RagBot workflow...")
52
  start_time = time.time()
53
+
54
  import os
55
+
56
  try:
57
  # Set working directory via environment so vector store paths resolve
58
  # without a process-global os.chdir() (which is thread-unsafe).
59
  ragbot_root = Path(__file__).parent.parent.parent.parent
60
  os.environ["RAGBOT_ROOT"] = str(ragbot_root)
61
  print(f"INFO: Project root: {ragbot_root}")
62
+
63
  # Temporarily chdir only during initialization (single-threaded at startup)
64
  original_dir = os.getcwd()
65
  os.chdir(ragbot_root)
66
+
67
  self.guild = create_guild()
68
  self.initialized = True
69
  self.init_time = datetime.now()
70
+
71
  elapsed = (time.time() - start_time) * 1000
72
  print(f"OK: RagBot initialized successfully ({elapsed:.0f}ms)")
73
+
74
  except Exception as e:
75
  print(f"ERROR: Failed to initialize RagBot: {e}")
76
  raise
77
+
78
  finally:
79
  # Restore original directory
80
  os.chdir(original_dir)
81
+
82
  def get_uptime_seconds(self) -> float:
83
  """Get API uptime in seconds"""
84
  if not self.init_time:
85
  return 0.0
86
  return (datetime.now() - self.init_time).total_seconds()
87
+
88
  def is_ready(self) -> bool:
89
  """Check if service is ready to handle requests"""
90
  return self.initialized and self.guild is not None
91
+
92
  def analyze(
93
  self,
94
+ biomarkers: dict[str, float],
95
+ patient_context: dict[str, Any],
96
+ model_prediction: dict[str, Any],
97
+ extracted_biomarkers: dict[str, float] = None
98
  ) -> AnalysisResponse:
99
  """
100
  Run complete analysis workflow and format full detailed response.
 
110
  """
111
  if not self.is_ready():
112
  raise RuntimeError("RagBot service not initialized. Call initialize() first.")
113
+
114
  request_id = f"req_{uuid.uuid4().hex[:12]}"
115
  start_time = time.time()
116
+
117
  try:
118
  # Create PatientInput
119
  patient_input = PatientInput(
 
121
  model_prediction=model_prediction,
122
  patient_context=patient_context
123
  )
124
+
125
  # Run workflow
126
  workflow_result = self.guild.run(patient_input)
127
+
128
  # Calculate processing time
129
  processing_time_ms = (time.time() - start_time) * 1000
130
+
131
  # Format response
132
  response = self._format_response(
133
  request_id=request_id,
 
138
  model_prediction=model_prediction,
139
  processing_time_ms=processing_time_ms
140
  )
141
+
142
  return response
143
+
144
  except Exception as e:
145
  # Re-raise with context
146
+ raise RuntimeError(f"Analysis failed during workflow execution: {e!s}") from e
147
+
148
  def _format_response(
149
  self,
150
  request_id: str,
151
+ workflow_result: dict[str, Any],
152
+ input_biomarkers: dict[str, float],
153
+ extracted_biomarkers: dict[str, float],
154
+ patient_context: dict[str, Any],
155
+ model_prediction: dict[str, Any],
156
  processing_time_ms: float
157
  ) -> AnalysisResponse:
158
  """
 
166
  - safety_alerts: list of SafetyAlert objects
167
  - sop_version, processing_timestamp, etc.
168
  """
169
+
170
  # The synthesizer output is nested inside final_response
171
  final_response = workflow_result.get("final_response", {}) or {}
172
+
173
  # Extract main prediction
174
  prediction = Prediction(
175
  disease=model_prediction["disease"],
176
  confidence=model_prediction["confidence"],
177
  probabilities=model_prediction.get("probabilities", {})
178
  )
179
+
180
  # Biomarker flags: prefer state-level data (BiomarkerFlag objects from validator),
181
  # fall back to synthesizer output
182
  state_flags = workflow_result.get("biomarker_flags", [])
 
195
  BiomarkerFlag(**flag) if isinstance(flag, dict) else BiomarkerFlag(**flag.model_dump())
196
  for flag in biomarker_flags_source
197
  ]
198
+
199
  # Safety alerts: prefer state-level data, fall back to synthesizer
200
  state_alerts = workflow_result.get("safety_alerts", [])
201
  if state_alerts:
 
213
  SafetyAlert(**alert) if isinstance(alert, dict) else SafetyAlert(**alert.model_dump())
214
  for alert in safety_alerts_source
215
  ]
216
+
217
  # Extract key drivers from synthesizer output
218
  key_drivers_data = final_response.get("key_drivers", [])
219
  if not key_drivers_data:
 
222
  for driver in key_drivers_data:
223
  if isinstance(driver, dict):
224
  key_drivers.append(KeyDriver(**driver))
225
+
226
  # Disease explanation from synthesizer
227
  disease_exp_data = final_response.get("disease_explanation", {})
228
  if not disease_exp_data:
 
232
  citations=disease_exp_data.get("citations", []),
233
  retrieved_chunks=disease_exp_data.get("retrieved_chunks")
234
  )
235
+
236
  # Recommendations from synthesizer
237
  recs_data = final_response.get("recommendations", {})
238
  if not recs_data:
 
245
  monitoring=recs_data.get("monitoring", []),
246
  follow_up=recs_data.get("follow_up")
247
  )
248
+
249
  # Confidence assessment from synthesizer
250
  conf_data = final_response.get("confidence_assessment", {})
251
  if not conf_data:
 
256
  limitations=conf_data.get("limitations", []),
257
  reasoning=conf_data.get("reasoning")
258
  )
259
+
260
  # Alternative diagnoses
261
  alternative_diagnoses = final_response.get("alternative_diagnoses")
262
  if alternative_diagnoses is None:
263
  alternative_diagnoses = final_response.get("analysis", {}).get("alternative_diagnoses")
264
+
265
  # Assemble complete analysis
266
  analysis = Analysis(
267
  biomarker_flags=biomarker_flags,
 
272
  confidence_assessment=confidence_assessment,
273
  alternative_diagnoses=alternative_diagnoses
274
  )
275
+
276
  # Agent outputs from state (these are src.state.AgentOutput objects)
277
  agent_outputs_data = workflow_result.get("agent_outputs", [])
278
  agent_outputs = []
 
281
  agent_outputs.append(AgentOutput(**agent_out.model_dump()))
282
  elif isinstance(agent_out, dict):
283
  agent_outputs.append(AgentOutput(**agent_out))
284
+
285
  # Workflow metadata
286
  workflow_metadata = {
287
  "sop_version": workflow_result.get("sop_version"),
 
289
  "agents_executed": len(agent_outputs),
290
  "workflow_success": True
291
  }
292
+
293
  # Conversational summary (if available)
294
  conversational_summary = final_response.get("conversational_summary")
295
  if not conversational_summary:
296
  conversational_summary = final_response.get("patient_summary", {}).get("narrative")
297
+
298
  # Generate conversational summary if not present
299
  if not conversational_summary:
300
  conversational_summary = self._generate_conversational_summary(
 
303
  key_drivers=key_drivers,
304
  recommendations=recommendations
305
  )
306
+
307
  # Assemble final response
308
  response = AnalysisResponse(
309
  status="success",
 
320
  processing_time_ms=processing_time_ms,
321
  sop_version=workflow_result.get("sop_version", "Baseline")
322
  )
323
+
324
  return response
325
+
326
  def _generate_conversational_summary(
327
  self,
328
  prediction: Prediction,
 
331
  recommendations: Recommendations
332
  ) -> str:
333
  """Generate a simple conversational summary"""
334
+
335
  summary_parts = []
336
  summary_parts.append("Hi there!\n")
337
  summary_parts.append("Based on your biomarkers, I analyzed your results.\n")
338
+
339
  # Prediction
340
  summary_parts.append(f"\nPrimary Finding: {prediction.disease}")
341
  summary_parts.append(f" Confidence: {prediction.confidence:.0%}\n")
342
+
343
  # Safety alerts
344
  if safety_alerts:
345
  summary_parts.append("\nIMPORTANT SAFETY ALERTS:")
346
  for alert in safety_alerts[:3]: # Top 3
347
  summary_parts.append(f" - {alert.biomarker}: {alert.message}")
348
  summary_parts.append(f" Action: {alert.action}")
349
+
350
  # Key drivers
351
  if key_drivers:
352
  summary_parts.append("\nWhy this prediction?")
353
  for driver in key_drivers[:3]: # Top 3
354
  summary_parts.append(f" - {driver.biomarker} ({driver.value}): {driver.explanation[:100]}...")
355
+
356
  # Recommendations
357
  if recommendations.immediate_actions:
358
  summary_parts.append("\nWhat You Should Do:")
359
  for i, action in enumerate(recommendations.immediate_actions[:3], 1):
360
  summary_parts.append(f" {i}. {action}")
361
+
362
  summary_parts.append("\nImportant: This is an AI-assisted analysis, NOT medical advice.")
363
  summary_parts.append(" Please consult a healthcare professional for proper diagnosis and treatment.")
364
+
365
  return "\n".join(summary_parts)
366
 
367
 
archive/evolution/__init__.py CHANGED
@@ -4,32 +4,26 @@ Self-improvement system for SOP optimization
4
  """
5
 
6
  from .director import (
7
- SOPGenePool,
8
  Diagnosis,
9
- SOPMutation,
10
  EvolvedSOPs,
 
 
11
  performance_diagnostician,
 
12
  sop_architect,
13
- run_evolution_cycle
14
- )
15
-
16
- from .pareto import (
17
- identify_pareto_front,
18
- visualize_pareto_frontier,
19
- print_pareto_summary,
20
- analyze_improvements
21
  )
 
22
 
23
  __all__ = [
24
- 'SOPGenePool',
25
  'Diagnosis',
26
- 'SOPMutation',
27
  'EvolvedSOPs',
28
- 'performance_diagnostician',
29
- 'sop_architect',
30
- 'run_evolution_cycle',
31
  'identify_pareto_front',
32
- 'visualize_pareto_frontier',
33
  'print_pareto_summary',
34
- 'analyze_improvements'
 
 
35
  ]
 
4
  """
5
 
6
# Re-export the evolution engine's public API from its two submodules.
from .director import (
    Diagnosis,
    EvolvedSOPs,
    SOPGenePool,
    SOPMutation,
    performance_diagnostician,
    run_evolution_cycle,
    sop_architect,
)
from .pareto import (
    analyze_improvements,
    identify_pareto_front,
    print_pareto_summary,
    visualize_pareto_frontier,
)

# Explicit public surface of the package (same names, same order as before).
__all__ = [
    'Diagnosis',
    'EvolvedSOPs',
    'SOPGenePool',
    'SOPMutation',
    'analyze_improvements',
    'identify_pareto_front',
    'performance_diagnostician',
    'print_pareto_summary',
    'run_evolution_cycle',
    'sop_architect',
    'visualize_pareto_frontier',
]
archive/evolution/director.py CHANGED
@@ -3,27 +3,28 @@ MediGuard AI RAG-Helper - Evolution Engine
3
  Outer Loop Director for SOP Evolution
4
  """
5
 
6
- import json
7
- from typing import Any, Callable, Dict, List, Literal, Optional
 
8
  from pydantic import BaseModel, Field
9
- from langchain_core.prompts import ChatPromptTemplate
10
  from src.config import ExplanationSOP
11
  from src.evaluation.evaluators import EvaluationResult
12
 
13
 
14
  class SOPGenePool:
15
  """Manages version control for evolving SOPs"""
16
-
17
  def __init__(self):
18
- self.pool: List[Dict[str, Any]] = []
19
- self.gene_pool: List[Dict[str, Any]] = [] # Alias for compatibility
20
  self.version_counter = 0
21
-
22
  def add(
23
  self,
24
  sop: ExplanationSOP,
25
  evaluation: EvaluationResult,
26
- parent_version: Optional[int] = None,
27
  description: str = ""
28
  ):
29
  """Add a new SOP to the gene pool"""
@@ -38,50 +39,50 @@ class SOPGenePool:
38
  self.pool.append(entry)
39
  self.gene_pool = self.pool # Keep in sync
40
  print(f"✓ Added SOP v{self.version_counter} to gene pool: {description}")
41
-
42
- def get_latest(self) -> Optional[Dict[str, Any]]:
43
  """Get the most recent SOP"""
44
  return self.pool[-1] if self.pool else None
45
-
46
- def get_by_version(self, version: int) -> Optional[Dict[str, Any]]:
47
  """Retrieve specific SOP version"""
48
  for entry in self.pool:
49
  if entry['version'] == version:
50
  return entry
51
  return None
52
-
53
- def get_best_by_metric(self, metric: str) -> Optional[Dict[str, Any]]:
54
  """Get SOP with highest score on specific metric"""
55
  if not self.pool:
56
  return None
57
-
58
  best = max(
59
  self.pool,
60
  key=lambda x: getattr(x['evaluation'], metric).score
61
  )
62
  return best
63
-
64
  def summary(self):
65
  """Print summary of all SOPs in pool"""
66
  print("\n" + "=" * 80)
67
  print("SOP GENE POOL SUMMARY")
68
  print("=" * 80)
69
-
70
  for entry in self.pool:
71
  v = entry['version']
72
  p = entry['parent']
73
  desc = entry['description']
74
  e = entry['evaluation']
75
-
76
  parent_str = "(Baseline)" if p is None else f"(Child of v{p})"
77
-
78
  print(f"\nSOP v{v} {parent_str}: {desc}")
79
  print(f" Clinical Accuracy: {e.clinical_accuracy.score:.2f}")
80
  print(f" Evidence Grounding: {e.evidence_grounding.score:.2f}")
81
  print(f" Actionability: {e.actionability.score:.2f}")
82
  print(f" Clarity: {e.clarity.score:.2f}")
83
  print(f" Safety & Completeness: {e.safety_completeness.score:.2f}")
84
-
85
  print("\n" + "=" * 80)
86
 
87
 
@@ -120,7 +121,7 @@ class SOPMutation(BaseModel):
120
 
121
  class EvolvedSOPs(BaseModel):
122
  """Container for mutated SOPs from Architect"""
123
- mutations: List[SOPMutation]
124
 
125
 
126
  def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
@@ -131,7 +132,7 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
131
  print("\n" + "=" * 70)
132
  print("EXECUTING: Performance Diagnostician")
133
  print("=" * 70)
134
-
135
  # Find lowest score programmatically (no LLM needed)
136
  scores = {
137
  'clinical_accuracy': evaluation.clinical_accuracy.score,
@@ -140,7 +141,7 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
140
  'clarity': evaluation.clarity.score,
141
  'safety_completeness': evaluation.safety_completeness.score
142
  }
143
-
144
  reasonings = {
145
  'clinical_accuracy': evaluation.clinical_accuracy.reasoning,
146
  'evidence_grounding': evaluation.evidence_grounding.reasoning,
@@ -148,11 +149,11 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
148
  'clarity': evaluation.clarity.reasoning,
149
  'safety_completeness': evaluation.safety_completeness.reasoning
150
  }
151
-
152
  primary_weakness = min(scores, key=scores.get)
153
  weakness_score = scores[primary_weakness]
154
  weakness_reasoning = reasonings[primary_weakness]
155
-
156
  # Generate detailed root cause analysis
157
  root_cause_map = {
158
  'clinical_accuracy': f"Clinical accuracy score ({weakness_score:.2f}) indicates potential issues with medical interpretations. {weakness_reasoning[:200]}",
@@ -161,7 +162,7 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
161
  'clarity': f"Clarity score ({weakness_score:.2f}) suggests readability issues. {weakness_reasoning[:200]}",
162
  'safety_completeness': f"Safety score ({weakness_score:.2f}) indicates missing risk discussions. {weakness_reasoning[:200]}"
163
  }
164
-
165
  recommendation_map = {
166
  'clinical_accuracy': "Increase RAG depth to access more authoritative medical sources.",
167
  'evidence_grounding': "Enforce strict citation requirements and increase RAG depth.",
@@ -169,17 +170,17 @@ def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
169
  'clarity': "Simplify language and reduce technical jargon for better readability.",
170
  'safety_completeness': "Add explicit safety warnings and ensure complete risk coverage."
171
  }
172
-
173
  diagnosis = Diagnosis(
174
  primary_weakness=primary_weakness,
175
  root_cause_analysis=root_cause_map[primary_weakness],
176
  recommendation=recommendation_map[primary_weakness]
177
  )
178
-
179
- print(f"\n✓ Diagnosis complete")
180
  print(f" Primary weakness: {diagnosis.primary_weakness} ({weakness_score:.3f})")
181
  print(f" Recommendation: {diagnosis.recommendation}")
182
-
183
  return diagnosis
184
 
185
 
@@ -195,9 +196,9 @@ def sop_architect(
195
  print("EXECUTING: SOP Architect")
196
  print("=" * 70)
197
  print(f"Target weakness: {diagnosis.primary_weakness}")
198
-
199
  weakness = diagnosis.primary_weakness
200
-
201
  # Generate mutations based on weakness type
202
  if weakness == 'clarity':
203
  mut1 = SOPMutation(
@@ -226,7 +227,7 @@ def sop_architect(
226
  critical_value_alert_mode=current_sop.critical_value_alert_mode,
227
  description="Balanced detail with fewer citations for readability"
228
  )
229
-
230
  elif weakness == 'evidence_grounding':
231
  mut1 = SOPMutation(
232
  disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
@@ -254,7 +255,7 @@ def sop_architect(
254
  critical_value_alert_mode=current_sop.critical_value_alert_mode,
255
  description="Moderate RAG increase with citation enforcement"
256
  )
257
-
258
  elif weakness == 'actionability':
259
  mut1 = SOPMutation(
260
  disease_explainer_k=current_sop.disease_explainer_k,
@@ -282,7 +283,7 @@ def sop_architect(
282
  critical_value_alert_mode='strict',
283
  description="Comprehensive approach with all agents enabled"
284
  )
285
-
286
  elif weakness == 'clinical_accuracy':
287
  mut1 = SOPMutation(
288
  disease_explainer_k=10,
@@ -310,7 +311,7 @@ def sop_architect(
310
  critical_value_alert_mode='strict',
311
  description="High RAG depth with comprehensive detail"
312
  )
313
-
314
  else: # safety_completeness
315
  mut1 = SOPMutation(
316
  disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
@@ -338,14 +339,14 @@ def sop_architect(
338
  critical_value_alert_mode='strict',
339
  description="Maximum coverage with all safety features"
340
  )
341
-
342
  evolved = EvolvedSOPs(mutations=[mut1, mut2])
343
-
344
  print(f"\n✓ Generated {len(evolved.mutations)} mutations")
345
  for i, mut in enumerate(evolved.mutations, 1):
346
  print(f" {i}. {mut.description}")
347
  print(f" Disease K: {mut.disease_explainer_k}, Detail: {mut.explainer_detail_level}")
348
-
349
  return evolved
350
 
351
 
@@ -354,7 +355,7 @@ def run_evolution_cycle(
354
  patient_input: Any,
355
  workflow_graph: Any,
356
  evaluation_func: Callable
357
- ) -> List[Dict[str, Any]]:
358
  """
359
  Executes one complete evolution cycle:
360
  1. Diagnose current best SOP
@@ -367,38 +368,37 @@ def run_evolution_cycle(
367
  print("\n" + "=" * 80)
368
  print("STARTING EVOLUTION CYCLE")
369
  print("=" * 80)
370
-
371
  # Get current best (for simplicity, use latest)
372
  current_best = gene_pool.get_latest()
373
  if not current_best:
374
  raise ValueError("Gene pool is empty. Add baseline SOP first.")
375
-
376
  parent_sop = current_best['sop']
377
  parent_eval = current_best['evaluation']
378
  parent_version = current_best['version']
379
-
380
  print(f"\nImproving upon SOP v{parent_version}")
381
-
382
  # Step 1: Diagnose
383
  diagnosis = performance_diagnostician(parent_eval)
384
-
385
  # Step 2: Generate mutations
386
  evolved_sops = sop_architect(diagnosis, parent_sop)
387
-
388
  # Step 3: Test each mutation
389
  new_entries = []
390
  for i, mutant_sop_model in enumerate(evolved_sops.mutations, 1):
391
  print(f"\n{'=' * 70}")
392
  print(f"TESTING MUTATION {i}/{len(evolved_sops.mutations)}: {mutant_sop_model.description}")
393
  print("=" * 70)
394
-
395
  # Convert SOPMutation to ExplanationSOP
396
  mutant_sop_dict = mutant_sop_model.model_dump()
397
  description = mutant_sop_dict.pop('description')
398
  mutant_sop = ExplanationSOP(**mutant_sop_dict)
399
-
400
  # Run workflow with mutated SOP
401
- from src.state import PatientInput
402
  from datetime import datetime
403
  graph_input = {
404
  "patient_biomarkers": patient_input.biomarkers,
@@ -414,17 +414,17 @@ def run_evolution_cycle(
414
  "processing_timestamp": datetime.now().isoformat(),
415
  "sop_version": description
416
  }
417
-
418
  try:
419
  final_state = workflow_graph.invoke(graph_input)
420
-
421
  # Evaluate output
422
  evaluation = evaluation_func(
423
  final_response=final_state['final_response'],
424
  agent_outputs=final_state['agent_outputs'],
425
  biomarkers=patient_input.biomarkers
426
  )
427
-
428
  # Add to gene pool
429
  gene_pool.add(
430
  sop=mutant_sop,
@@ -432,7 +432,7 @@ def run_evolution_cycle(
432
  parent_version=parent_version,
433
  description=description
434
  )
435
-
436
  new_entries.append({
437
  "sop": mutant_sop,
438
  "evaluation": evaluation,
@@ -441,9 +441,9 @@ def run_evolution_cycle(
441
  except Exception as e:
442
  print(f"❌ Mutation {i} failed: {e}")
443
  continue
444
-
445
  print("\n" + "=" * 80)
446
  print("EVOLUTION CYCLE COMPLETE")
447
  print("=" * 80)
448
-
449
  return new_entries
 
3
  Outer Loop Director for SOP Evolution
4
  """
5
 
6
+ from collections.abc import Callable
7
+ from typing import Any, Literal
8
+
9
  from pydantic import BaseModel, Field
10
+
11
  from src.config import ExplanationSOP
12
  from src.evaluation.evaluators import EvaluationResult
13
 
14
 
15
  class SOPGenePool:
16
  """Manages version control for evolving SOPs"""
17
+
18
  def __init__(self):
19
+ self.pool: list[dict[str, Any]] = []
20
+ self.gene_pool: list[dict[str, Any]] = [] # Alias for compatibility
21
  self.version_counter = 0
22
+
23
  def add(
24
  self,
25
  sop: ExplanationSOP,
26
  evaluation: EvaluationResult,
27
+ parent_version: int | None = None,
28
  description: str = ""
29
  ):
30
  """Add a new SOP to the gene pool"""
 
39
  self.pool.append(entry)
40
  self.gene_pool = self.pool # Keep in sync
41
  print(f"✓ Added SOP v{self.version_counter} to gene pool: {description}")
42
+
43
+ def get_latest(self) -> dict[str, Any] | None:
44
  """Get the most recent SOP"""
45
  return self.pool[-1] if self.pool else None
46
+
47
+ def get_by_version(self, version: int) -> dict[str, Any] | None:
48
  """Retrieve specific SOP version"""
49
  for entry in self.pool:
50
  if entry['version'] == version:
51
  return entry
52
  return None
53
+
54
+ def get_best_by_metric(self, metric: str) -> dict[str, Any] | None:
55
  """Get SOP with highest score on specific metric"""
56
  if not self.pool:
57
  return None
58
+
59
  best = max(
60
  self.pool,
61
  key=lambda x: getattr(x['evaluation'], metric).score
62
  )
63
  return best
64
+
65
  def summary(self):
66
  """Print summary of all SOPs in pool"""
67
  print("\n" + "=" * 80)
68
  print("SOP GENE POOL SUMMARY")
69
  print("=" * 80)
70
+
71
  for entry in self.pool:
72
  v = entry['version']
73
  p = entry['parent']
74
  desc = entry['description']
75
  e = entry['evaluation']
76
+
77
  parent_str = "(Baseline)" if p is None else f"(Child of v{p})"
78
+
79
  print(f"\nSOP v{v} {parent_str}: {desc}")
80
  print(f" Clinical Accuracy: {e.clinical_accuracy.score:.2f}")
81
  print(f" Evidence Grounding: {e.evidence_grounding.score:.2f}")
82
  print(f" Actionability: {e.actionability.score:.2f}")
83
  print(f" Clarity: {e.clarity.score:.2f}")
84
  print(f" Safety & Completeness: {e.safety_completeness.score:.2f}")
85
+
86
  print("\n" + "=" * 80)
87
 
88
 
 
121
 
122
class EvolvedSOPs(BaseModel):
    """Container for mutated SOPs from Architect.

    Returned by ``sop_architect``; each mutation is a candidate SOP variant
    to be tested against the parent SOP in the next evolution cycle.
    """
    # Candidate SOP variants produced for one evolution cycle.
    mutations: list[SOPMutation]
125
 
126
 
127
  def performance_diagnostician(evaluation: EvaluationResult) -> Diagnosis:
 
132
  print("\n" + "=" * 70)
133
  print("EXECUTING: Performance Diagnostician")
134
  print("=" * 70)
135
+
136
  # Find lowest score programmatically (no LLM needed)
137
  scores = {
138
  'clinical_accuracy': evaluation.clinical_accuracy.score,
 
141
  'clarity': evaluation.clarity.score,
142
  'safety_completeness': evaluation.safety_completeness.score
143
  }
144
+
145
  reasonings = {
146
  'clinical_accuracy': evaluation.clinical_accuracy.reasoning,
147
  'evidence_grounding': evaluation.evidence_grounding.reasoning,
 
149
  'clarity': evaluation.clarity.reasoning,
150
  'safety_completeness': evaluation.safety_completeness.reasoning
151
  }
152
+
153
  primary_weakness = min(scores, key=scores.get)
154
  weakness_score = scores[primary_weakness]
155
  weakness_reasoning = reasonings[primary_weakness]
156
+
157
  # Generate detailed root cause analysis
158
  root_cause_map = {
159
  'clinical_accuracy': f"Clinical accuracy score ({weakness_score:.2f}) indicates potential issues with medical interpretations. {weakness_reasoning[:200]}",
 
162
  'clarity': f"Clarity score ({weakness_score:.2f}) suggests readability issues. {weakness_reasoning[:200]}",
163
  'safety_completeness': f"Safety score ({weakness_score:.2f}) indicates missing risk discussions. {weakness_reasoning[:200]}"
164
  }
165
+
166
  recommendation_map = {
167
  'clinical_accuracy': "Increase RAG depth to access more authoritative medical sources.",
168
  'evidence_grounding': "Enforce strict citation requirements and increase RAG depth.",
 
170
  'clarity': "Simplify language and reduce technical jargon for better readability.",
171
  'safety_completeness': "Add explicit safety warnings and ensure complete risk coverage."
172
  }
173
+
174
  diagnosis = Diagnosis(
175
  primary_weakness=primary_weakness,
176
  root_cause_analysis=root_cause_map[primary_weakness],
177
  recommendation=recommendation_map[primary_weakness]
178
  )
179
+
180
+ print("\n✓ Diagnosis complete")
181
  print(f" Primary weakness: {diagnosis.primary_weakness} ({weakness_score:.3f})")
182
  print(f" Recommendation: {diagnosis.recommendation}")
183
+
184
  return diagnosis
185
 
186
 
 
196
  print("EXECUTING: SOP Architect")
197
  print("=" * 70)
198
  print(f"Target weakness: {diagnosis.primary_weakness}")
199
+
200
  weakness = diagnosis.primary_weakness
201
+
202
  # Generate mutations based on weakness type
203
  if weakness == 'clarity':
204
  mut1 = SOPMutation(
 
227
  critical_value_alert_mode=current_sop.critical_value_alert_mode,
228
  description="Balanced detail with fewer citations for readability"
229
  )
230
+
231
  elif weakness == 'evidence_grounding':
232
  mut1 = SOPMutation(
233
  disease_explainer_k=min(10, current_sop.disease_explainer_k + 2),
 
255
  critical_value_alert_mode=current_sop.critical_value_alert_mode,
256
  description="Moderate RAG increase with citation enforcement"
257
  )
258
+
259
  elif weakness == 'actionability':
260
  mut1 = SOPMutation(
261
  disease_explainer_k=current_sop.disease_explainer_k,
 
283
  critical_value_alert_mode='strict',
284
  description="Comprehensive approach with all agents enabled"
285
  )
286
+
287
  elif weakness == 'clinical_accuracy':
288
  mut1 = SOPMutation(
289
  disease_explainer_k=10,
 
311
  critical_value_alert_mode='strict',
312
  description="High RAG depth with comprehensive detail"
313
  )
314
+
315
  else: # safety_completeness
316
  mut1 = SOPMutation(
317
  disease_explainer_k=min(10, current_sop.disease_explainer_k + 1),
 
339
  critical_value_alert_mode='strict',
340
  description="Maximum coverage with all safety features"
341
  )
342
+
343
  evolved = EvolvedSOPs(mutations=[mut1, mut2])
344
+
345
  print(f"\n✓ Generated {len(evolved.mutations)} mutations")
346
  for i, mut in enumerate(evolved.mutations, 1):
347
  print(f" {i}. {mut.description}")
348
  print(f" Disease K: {mut.disease_explainer_k}, Detail: {mut.explainer_detail_level}")
349
+
350
  return evolved
351
 
352
 
 
355
  patient_input: Any,
356
  workflow_graph: Any,
357
  evaluation_func: Callable
358
+ ) -> list[dict[str, Any]]:
359
  """
360
  Executes one complete evolution cycle:
361
  1. Diagnose current best SOP
 
368
  print("\n" + "=" * 80)
369
  print("STARTING EVOLUTION CYCLE")
370
  print("=" * 80)
371
+
372
  # Get current best (for simplicity, use latest)
373
  current_best = gene_pool.get_latest()
374
  if not current_best:
375
  raise ValueError("Gene pool is empty. Add baseline SOP first.")
376
+
377
  parent_sop = current_best['sop']
378
  parent_eval = current_best['evaluation']
379
  parent_version = current_best['version']
380
+
381
  print(f"\nImproving upon SOP v{parent_version}")
382
+
383
  # Step 1: Diagnose
384
  diagnosis = performance_diagnostician(parent_eval)
385
+
386
  # Step 2: Generate mutations
387
  evolved_sops = sop_architect(diagnosis, parent_sop)
388
+
389
  # Step 3: Test each mutation
390
  new_entries = []
391
  for i, mutant_sop_model in enumerate(evolved_sops.mutations, 1):
392
  print(f"\n{'=' * 70}")
393
  print(f"TESTING MUTATION {i}/{len(evolved_sops.mutations)}: {mutant_sop_model.description}")
394
  print("=" * 70)
395
+
396
  # Convert SOPMutation to ExplanationSOP
397
  mutant_sop_dict = mutant_sop_model.model_dump()
398
  description = mutant_sop_dict.pop('description')
399
  mutant_sop = ExplanationSOP(**mutant_sop_dict)
400
+
401
  # Run workflow with mutated SOP
 
402
  from datetime import datetime
403
  graph_input = {
404
  "patient_biomarkers": patient_input.biomarkers,
 
414
  "processing_timestamp": datetime.now().isoformat(),
415
  "sop_version": description
416
  }
417
+
418
  try:
419
  final_state = workflow_graph.invoke(graph_input)
420
+
421
  # Evaluate output
422
  evaluation = evaluation_func(
423
  final_response=final_state['final_response'],
424
  agent_outputs=final_state['agent_outputs'],
425
  biomarkers=patient_input.biomarkers
426
  )
427
+
428
  # Add to gene pool
429
  gene_pool.add(
430
  sop=mutant_sop,
 
432
  parent_version=parent_version,
433
  description=description
434
  )
435
+
436
  new_entries.append({
437
  "sop": mutant_sop,
438
  "evaluation": evaluation,
 
441
  except Exception as e:
442
  print(f"❌ Mutation {i} failed: {e}")
443
  continue
444
+
445
  print("\n" + "=" * 80)
446
  print("EVOLUTION CYCLE COMPLETE")
447
  print("=" * 80)
448
+
449
  return new_entries
archive/evolution/pareto.py CHANGED
@@ -3,14 +3,16 @@ Pareto Frontier Analysis
3
  Identifies optimal trade-offs in multi-objective optimization
4
  """
5
 
6
- import numpy as np
7
- from typing import List, Dict, Any
8
  import matplotlib
 
 
9
  matplotlib.use('Agg') # Use non-interactive backend
10
  import matplotlib.pyplot as plt
11
 
12
 
13
- def identify_pareto_front(gene_pool_entries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
14
  """
15
  Identifies non-dominated solutions (Pareto Frontier).
16
 
@@ -19,32 +21,32 @@ def identify_pareto_front(gene_pool_entries: List[Dict[str, Any]]) -> List[Dict[
19
  - Strictly better on AT LEAST ONE metric
20
  """
21
  pareto_front = []
22
-
23
  for i, candidate in enumerate(gene_pool_entries):
24
  is_dominated = False
25
-
26
  # Get candidate's 5D score vector
27
  cand_scores = np.array(candidate['evaluation'].to_vector())
28
-
29
  for j, other in enumerate(gene_pool_entries):
30
  if i == j:
31
  continue
32
-
33
  # Get other solution's 5D vector
34
  other_scores = np.array(other['evaluation'].to_vector())
35
-
36
  # Check domination: other >= candidate on ALL, other > candidate on SOME
37
  if np.all(other_scores >= cand_scores) and np.any(other_scores > cand_scores):
38
  is_dominated = True
39
  break
40
-
41
  if not is_dominated:
42
  pareto_front.append(candidate)
43
-
44
  return pareto_front
45
 
46
 
47
- def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
48
  """
49
  Creates two visualizations:
50
  1. Parallel coordinates plot (5D)
@@ -53,16 +55,16 @@ def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
53
  if not pareto_front:
54
  print("No solutions on Pareto front to visualize")
55
  return
56
-
57
  fig = plt.figure(figsize=(18, 7))
58
-
59
  # --- Plot 1: Bar Chart (since pandas might not be available) ---
60
  ax1 = plt.subplot(1, 2, 1)
61
-
62
  metrics = ['Clinical\nAccuracy', 'Evidence\nGrounding', 'Actionability', 'Clarity', 'Safety']
63
  x = np.arange(len(metrics))
64
  width = 0.8 / len(pareto_front)
65
-
66
  for idx, entry in enumerate(pareto_front):
67
  e = entry['evaluation']
68
  scores = [
@@ -72,11 +74,11 @@ def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
72
  e.clarity.score,
73
  e.safety_completeness.score
74
  ]
75
-
76
  offset = (idx - len(pareto_front) / 2) * width + width / 2
77
  label = f"SOP v{entry['version']}"
78
  ax1.bar(x + offset, scores, width, label=label, alpha=0.8)
79
-
80
  ax1.set_xlabel('Metrics', fontsize=12)
81
  ax1.set_ylabel('Score', fontsize=12)
82
  ax1.set_title('5D Performance Comparison (Bar Chart)', fontsize=14)
@@ -85,17 +87,17 @@ def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
85
  ax1.set_ylim(0, 1.0)
86
  ax1.legend(loc='upper left')
87
  ax1.grid(True, alpha=0.3, axis='y')
88
-
89
  # --- Plot 2: Radar Chart ---
90
  ax2 = plt.subplot(1, 2, 2, projection='polar')
91
-
92
- categories = ['Clinical\nAccuracy', 'Evidence\nGrounding',
93
  'Actionability', 'Clarity', 'Safety']
94
  num_vars = len(categories)
95
-
96
  angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
97
  angles += angles[:1]
98
-
99
  for entry in pareto_front:
100
  e = entry['evaluation']
101
  values = [
@@ -106,47 +108,47 @@ def visualize_pareto_frontier(pareto_front: List[Dict[str, Any]]):
106
  e.safety_completeness.score
107
  ]
108
  values += values[:1]
109
-
110
  desc = entry.get('description', '')[:30]
111
  label = f"SOP v{entry['version']}: {desc}"
112
  ax2.plot(angles, values, 'o-', linewidth=2, label=label)
113
  ax2.fill(angles, values, alpha=0.15)
114
-
115
  ax2.set_xticks(angles[:-1])
116
  ax2.set_xticklabels(categories, size=10)
117
  ax2.set_ylim(0, 1)
118
  ax2.set_title('5D Performance Profiles (Radar Chart)', size=14, y=1.08)
119
  ax2.legend(loc='upper left', bbox_to_anchor=(1.2, 1.0), fontsize=9)
120
  ax2.grid(True)
121
-
122
  plt.tight_layout()
123
-
124
  # Create data directory if it doesn't exist
125
  from pathlib import Path
126
  data_dir = Path('data')
127
  data_dir.mkdir(exist_ok=True)
128
-
129
  output_path = data_dir / 'pareto_frontier_analysis.png'
130
  plt.savefig(output_path, dpi=300, bbox_inches='tight')
131
  plt.close()
132
-
133
  print(f"\n✓ Visualization saved to: {output_path}")
134
 
135
 
136
- def print_pareto_summary(pareto_front: List[Dict[str, Any]]):
137
  """Print human-readable summary of Pareto frontier"""
138
  print("\n" + "=" * 80)
139
  print("PARETO FRONTIER ANALYSIS")
140
  print("=" * 80)
141
-
142
  print(f"\nFound {len(pareto_front)} optimal (non-dominated) solutions:\n")
143
-
144
  for entry in pareto_front:
145
  v = entry['version']
146
  p = entry.get('parent')
147
  desc = entry.get('description', 'Baseline')
148
  e = entry['evaluation']
149
-
150
  print(f"SOP v{v} {f'(Child of v{p})' if p else '(Baseline)'}")
151
  print(f" Description: {desc}")
152
  print(f" Clinical Accuracy: {e.clinical_accuracy.score:.3f}")
@@ -154,12 +156,12 @@ def print_pareto_summary(pareto_front: List[Dict[str, Any]]):
154
  print(f" Actionability: {e.actionability.score:.3f}")
155
  print(f" Clarity: {e.clarity.score:.3f}")
156
  print(f" Safety & Completeness: {e.safety_completeness.score:.3f}")
157
-
158
  # Calculate average
159
  avg_score = np.mean(e.to_vector())
160
  print(f" Average Score: {avg_score:.3f}")
161
  print()
162
-
163
  print("=" * 80)
164
  print("\nRECOMMENDATION:")
165
  print("Review the visualizations and choose the SOP that best matches")
@@ -167,46 +169,46 @@ def print_pareto_summary(pareto_front: List[Dict[str, Any]]):
167
  print("=" * 80)
168
 
169
 
170
- def analyze_improvements(gene_pool_entries: List[Dict[str, Any]]):
171
  """Analyze improvements over baseline"""
172
  if len(gene_pool_entries) < 2:
173
  print("\n⚠️ Not enough SOPs to analyze improvements")
174
  return
175
-
176
  baseline = gene_pool_entries[0]
177
  baseline_scores = np.array(baseline['evaluation'].to_vector())
178
-
179
  print("\n" + "=" * 80)
180
  print("IMPROVEMENT ANALYSIS")
181
  print("=" * 80)
182
-
183
  print(f"\nBaseline (v{baseline['version']}): {baseline.get('description', 'Initial')}")
184
  print(f" Average Score: {np.mean(baseline_scores):.3f}")
185
-
186
  improvements_found = False
187
  for entry in gene_pool_entries[1:]:
188
  scores = np.array(entry['evaluation'].to_vector())
189
  avg_score = np.mean(scores)
190
  baseline_avg = np.mean(baseline_scores)
191
-
192
  if avg_score > baseline_avg:
193
  improvements_found = True
194
  improvement_pct = ((avg_score - baseline_avg) / baseline_avg) * 100
195
-
196
- print(f"\n✓ SOP v{entry['version']}: {entry.get('description', '')}")
197
  print(f" Average Score: {avg_score:.3f} (+{improvement_pct:.1f}% vs baseline)")
198
-
199
  # Show per-metric improvements
200
- metric_names = ['Clinical Accuracy', 'Evidence Grounding', 'Actionability',
201
  'Clarity', 'Safety & Completeness']
202
  for i, (name, score, baseline_score) in enumerate(zip(metric_names, scores, baseline_scores)):
203
  diff = score - baseline_score
204
  if abs(diff) > 0.01: # Show significant changes
205
  symbol = "↑" if diff > 0 else "↓"
206
  print(f" {name}: {score:.3f} {symbol} ({diff:+.3f})")
207
-
208
  if not improvements_found:
209
  print("\n⚠️ No improvements found over baseline yet")
210
  print(" Consider running more evolution cycles or adjusting mutation strategies")
211
-
212
  print("\n" + "=" * 80)
 
3
  Identifies optimal trade-offs in multi-objective optimization
4
  """
5
 
6
+ from typing import Any
7
+
8
  import matplotlib
9
+ import numpy as np
10
+
11
  matplotlib.use('Agg') # Use non-interactive backend
12
  import matplotlib.pyplot as plt
13
 
14
 
15
def identify_pareto_front(gene_pool_entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return the non-dominated entries (the Pareto frontier).

    An entry is *dominated* when some other entry scores at least as high on
    every metric and strictly higher on at least one metric. Entries whose
    score vectors are identical do not dominate each other, so ties survive.
    """
    # Pre-compute each entry's 5D score vector once.
    vectors = [np.array(entry['evaluation'].to_vector()) for entry in gene_pool_entries]

    def _is_dominated(idx: int) -> bool:
        # True when any other vector beats vectors[idx] on all metrics
        # and strictly on at least one.
        mine = vectors[idx]
        return any(
            np.all(other >= mine) and np.any(other > mine)
            for j, other in enumerate(vectors)
            if j != idx
        )

    return [
        entry
        for idx, entry in enumerate(gene_pool_entries)
        if not _is_dominated(idx)
    ]
47
 
48
 
49
+ def visualize_pareto_frontier(pareto_front: list[dict[str, Any]]):
50
  """
51
  Creates two visualizations:
52
  1. Parallel coordinates plot (5D)
 
55
  if not pareto_front:
56
  print("No solutions on Pareto front to visualize")
57
  return
58
+
59
  fig = plt.figure(figsize=(18, 7))
60
+
61
  # --- Plot 1: Bar Chart (since pandas might not be available) ---
62
  ax1 = plt.subplot(1, 2, 1)
63
+
64
  metrics = ['Clinical\nAccuracy', 'Evidence\nGrounding', 'Actionability', 'Clarity', 'Safety']
65
  x = np.arange(len(metrics))
66
  width = 0.8 / len(pareto_front)
67
+
68
  for idx, entry in enumerate(pareto_front):
69
  e = entry['evaluation']
70
  scores = [
 
74
  e.clarity.score,
75
  e.safety_completeness.score
76
  ]
77
+
78
  offset = (idx - len(pareto_front) / 2) * width + width / 2
79
  label = f"SOP v{entry['version']}"
80
  ax1.bar(x + offset, scores, width, label=label, alpha=0.8)
81
+
82
  ax1.set_xlabel('Metrics', fontsize=12)
83
  ax1.set_ylabel('Score', fontsize=12)
84
  ax1.set_title('5D Performance Comparison (Bar Chart)', fontsize=14)
 
87
  ax1.set_ylim(0, 1.0)
88
  ax1.legend(loc='upper left')
89
  ax1.grid(True, alpha=0.3, axis='y')
90
+
91
  # --- Plot 2: Radar Chart ---
92
  ax2 = plt.subplot(1, 2, 2, projection='polar')
93
+
94
+ categories = ['Clinical\nAccuracy', 'Evidence\nGrounding',
95
  'Actionability', 'Clarity', 'Safety']
96
  num_vars = len(categories)
97
+
98
  angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
99
  angles += angles[:1]
100
+
101
  for entry in pareto_front:
102
  e = entry['evaluation']
103
  values = [
 
108
  e.safety_completeness.score
109
  ]
110
  values += values[:1]
111
+
112
  desc = entry.get('description', '')[:30]
113
  label = f"SOP v{entry['version']}: {desc}"
114
  ax2.plot(angles, values, 'o-', linewidth=2, label=label)
115
  ax2.fill(angles, values, alpha=0.15)
116
+
117
  ax2.set_xticks(angles[:-1])
118
  ax2.set_xticklabels(categories, size=10)
119
  ax2.set_ylim(0, 1)
120
  ax2.set_title('5D Performance Profiles (Radar Chart)', size=14, y=1.08)
121
  ax2.legend(loc='upper left', bbox_to_anchor=(1.2, 1.0), fontsize=9)
122
  ax2.grid(True)
123
+
124
  plt.tight_layout()
125
+
126
  # Create data directory if it doesn't exist
127
  from pathlib import Path
128
  data_dir = Path('data')
129
  data_dir.mkdir(exist_ok=True)
130
+
131
  output_path = data_dir / 'pareto_frontier_analysis.png'
132
  plt.savefig(output_path, dpi=300, bbox_inches='tight')
133
  plt.close()
134
+
135
  print(f"\n✓ Visualization saved to: {output_path}")
136
 
137
 
138
+ def print_pareto_summary(pareto_front: list[dict[str, Any]]):
139
  """Print human-readable summary of Pareto frontier"""
140
  print("\n" + "=" * 80)
141
  print("PARETO FRONTIER ANALYSIS")
142
  print("=" * 80)
143
+
144
  print(f"\nFound {len(pareto_front)} optimal (non-dominated) solutions:\n")
145
+
146
  for entry in pareto_front:
147
  v = entry['version']
148
  p = entry.get('parent')
149
  desc = entry.get('description', 'Baseline')
150
  e = entry['evaluation']
151
+
152
  print(f"SOP v{v} {f'(Child of v{p})' if p else '(Baseline)'}")
153
  print(f" Description: {desc}")
154
  print(f" Clinical Accuracy: {e.clinical_accuracy.score:.3f}")
 
156
  print(f" Actionability: {e.actionability.score:.3f}")
157
  print(f" Clarity: {e.clarity.score:.3f}")
158
  print(f" Safety & Completeness: {e.safety_completeness.score:.3f}")
159
+
160
  # Calculate average
161
  avg_score = np.mean(e.to_vector())
162
  print(f" Average Score: {avg_score:.3f}")
163
  print()
164
+
165
  print("=" * 80)
166
  print("\nRECOMMENDATION:")
167
  print("Review the visualizations and choose the SOP that best matches")
 
169
  print("=" * 80)
170
 
171
 
172
def analyze_improvements(gene_pool_entries: list[dict[str, Any]]):
    """Print an improvement report comparing every SOP to the baseline.

    The first entry in ``gene_pool_entries`` is treated as the baseline.
    Each later entry whose average score beats the baseline's average is
    reported, including per-metric deltas larger than 0.01.

    Args:
        gene_pool_entries: Gene-pool records, each a dict with at least an
            ``'evaluation'`` object exposing ``to_vector()``, plus optional
            ``'version'`` / ``'description'`` keys.

    Returns:
        None; output is printed to stdout.
    """
    if len(gene_pool_entries) < 2:
        print("\n⚠️ Not enough SOPs to analyze improvements")
        return

    baseline = gene_pool_entries[0]
    baseline_scores = np.array(baseline['evaluation'].to_vector())
    # Hoisted out of the loop: these are invariant across entries.
    baseline_avg = np.mean(baseline_scores)
    metric_names = ['Clinical Accuracy', 'Evidence Grounding', 'Actionability',
                    'Clarity', 'Safety & Completeness']

    print("\n" + "=" * 80)
    print("IMPROVEMENT ANALYSIS")
    print("=" * 80)

    print(f"\nBaseline (v{baseline['version']}): {baseline.get('description', 'Initial')}")
    print(f" Average Score: {np.mean(baseline_scores):.3f}")

    improvements_found = False
    for entry in gene_pool_entries[1:]:
        scores = np.array(entry['evaluation'].to_vector())
        avg_score = np.mean(scores)

        if avg_score > baseline_avg:
            improvements_found = True
            improvement_pct = ((avg_score - baseline_avg) / baseline_avg) * 100

            print(f"\n✓ SOP v{entry['version']}: {entry.get('description', '')}")
            print(f" Average Score: {avg_score:.3f} (+{improvement_pct:.1f}% vs baseline)")

            # Show per-metric improvements (index was unused, so plain zip
            # replaces the previous enumerate(zip(...)) — fixes ruff B007).
            for name, score, baseline_score in zip(metric_names, scores, baseline_scores):
                diff = score - baseline_score
                if abs(diff) > 0.01:  # Show significant changes only
                    symbol = "↑" if diff > 0 else "↓"
                    print(f" {name}: {score:.3f} {symbol} ({diff:+.3f})")

    if not improvements_found:
        print("\n⚠️ No improvements found over baseline yet")
        print(" Consider running more evolution cycles or adjusting mutation strategies")

    print("\n" + "=" * 80)
archive/sop_evolution.py CHANGED
@@ -8,9 +8,10 @@ from __future__ import annotations
8
 
9
  from datetime import datetime, timedelta
10
 
11
- from airflow import DAG
12
  from airflow.operators.python import PythonOperator
13
 
 
 
14
  default_args = {
15
  "owner": "mediguard",
16
  "retries": 1,
 
8
 
9
  from datetime import datetime, timedelta
10
 
 
11
  from airflow.operators.python import PythonOperator
12
 
13
+ from airflow import DAG
14
+
15
  default_args = {
16
  "owner": "mediguard",
17
  "retries": 1,
{tests → archive/tests}/test_evolution_loop.py RENAMED
@@ -10,20 +10,20 @@ from pathlib import Path
10
  project_root = Path(__file__).parent.parent
11
  sys.path.insert(0, str(project_root))
12
 
13
- from src.workflow import create_guild
14
- from src.pdf_processor import get_all_retrievers
 
15
  from src.config import BASELINE_SOP
16
- from src.state import PatientInput, GuildState
17
  from src.evaluation.evaluators import run_full_evaluation
18
  from src.evolution.director import SOPGenePool, run_evolution_cycle
19
  from src.evolution.pareto import (
 
20
  identify_pareto_front,
21
- visualize_pareto_frontier,
22
  print_pareto_summary,
23
- analyze_improvements
24
  )
25
- from datetime import datetime
26
- from typing import Dict, Any
27
 
28
 
29
  def create_test_patient() -> PatientInput:
@@ -53,8 +53,8 @@ def create_test_patient() -> PatientInput:
53
  "Chloride": 102.0,
54
  "Bicarbonate": 24.0
55
  }
56
-
57
- model_prediction: Dict[str, Any] = {
58
  'disease': 'Type 2 Diabetes',
59
  'confidence': 0.92,
60
  'probabilities': {
@@ -64,7 +64,7 @@ def create_test_patient() -> PatientInput:
64
  },
65
  'prediction_timestamp': '2025-01-01T10:00:00'
66
  }
67
-
68
  patient_context = {
69
  'patient_id': 'TEST-001',
70
  'age': 55,
@@ -74,7 +74,7 @@ def create_test_patient() -> PatientInput:
74
  'current_medications': ["Metformin 500mg"],
75
  'query': "My blood sugar has been high lately. What should I do?"
76
  }
77
-
78
  return PatientInput(
79
  biomarkers=biomarkers,
80
  model_prediction=model_prediction,
@@ -87,19 +87,19 @@ def main():
87
  print("\n" + "=" * 80)
88
  print("PHASE 3: SELF-IMPROVEMENT LOOP TEST")
89
  print("=" * 80)
90
-
91
  # Setup
92
  print("\n1. Initializing system...")
93
  guild = create_guild()
94
  patient = create_test_patient()
95
-
96
  # Initialize gene pool with baseline
97
  print("\n2. Creating SOP Gene Pool...")
98
  gene_pool = SOPGenePool()
99
-
100
  print("\n3. Evaluating Baseline SOP...")
101
  # Run workflow with baseline SOP
102
-
103
  initial_state: GuildState = {
104
  'patient_biomarkers': patient.biomarkers,
105
  'model_prediction': patient.model_prediction,
@@ -113,41 +113,41 @@ def main():
113
  'processing_timestamp': datetime.now().isoformat(),
114
  'sop_version': "Baseline"
115
  }
116
-
117
  guild_state = guild.workflow.invoke(initial_state)
118
-
119
  baseline_response = guild_state['final_response']
120
  agent_outputs = guild_state['agent_outputs']
121
-
122
  baseline_eval = run_full_evaluation(
123
  final_response=baseline_response,
124
  agent_outputs=agent_outputs,
125
  biomarkers=patient.biomarkers
126
  )
127
-
128
  gene_pool.add(
129
  sop=BASELINE_SOP,
130
  evaluation=baseline_eval,
131
  parent_version=None,
132
  description="Baseline SOP"
133
  )
134
-
135
  print(f"\n✓ Baseline Average Score: {baseline_eval.average_score():.3f}")
136
  print(f" Clinical Accuracy: {baseline_eval.clinical_accuracy.score:.3f}")
137
  print(f" Evidence Grounding: {baseline_eval.evidence_grounding.score:.3f}")
138
  print(f" Actionability: {baseline_eval.actionability.score:.3f}")
139
  print(f" Clarity: {baseline_eval.clarity.score:.3f}")
140
  print(f" Safety & Completeness: {baseline_eval.safety_completeness.score:.3f}")
141
-
142
  # Run evolution cycles
143
  num_cycles = 2
144
  print(f"\n4. Running {num_cycles} Evolution Cycles...")
145
-
146
  for cycle in range(1, num_cycles + 1):
147
  print(f"\n{'─' * 80}")
148
  print(f"EVOLUTION CYCLE {cycle}")
149
  print(f"{'─' * 80}")
150
-
151
  try:
152
  # Create evaluation function for this cycle
153
  def eval_func(final_response, agent_outputs, biomarkers):
@@ -156,61 +156,61 @@ def main():
156
  agent_outputs=agent_outputs,
157
  biomarkers=biomarkers
158
  )
159
-
160
  new_entries = run_evolution_cycle(
161
  gene_pool=gene_pool,
162
  patient_input=patient,
163
  workflow_graph=guild.workflow,
164
  evaluation_func=eval_func
165
  )
166
-
167
  print(f"\n✓ Cycle {cycle} complete: Added {len(new_entries)} new SOPs to gene pool")
168
-
169
  for entry in new_entries:
170
  print(f"\n SOP v{entry['version']}: {entry['description']}")
171
  print(f" Average Score: {entry['evaluation'].average_score():.3f}")
172
-
173
  except Exception as e:
174
  print(f"\n⚠️ Cycle {cycle} encountered error: {e}")
175
  print("Continuing to next cycle...")
176
-
177
  # Show gene pool summary
178
  print("\n5. Gene Pool Summary:")
179
  gene_pool.summary()
180
-
181
  # Pareto Analysis
182
  print("\n6. Identifying Pareto Frontier...")
183
  all_entries = gene_pool.gene_pool
184
  pareto_front = identify_pareto_front(all_entries)
185
-
186
  print(f"\n✓ Pareto frontier contains {len(pareto_front)} non-dominated solutions")
187
  print_pareto_summary(pareto_front)
188
-
189
  # Improvement Analysis
190
  print("\n7. Analyzing Improvements...")
191
  analyze_improvements(all_entries)
192
-
193
  # Visualizations
194
  print("\n8. Generating Visualizations...")
195
  visualize_pareto_frontier(pareto_front)
196
-
197
  # Final Summary
198
  print("\n" + "=" * 80)
199
  print("EVOLUTION TEST COMPLETE")
200
  print("=" * 80)
201
-
202
  print(f"\n✓ Total SOPs in Gene Pool: {len(all_entries)}")
203
  print(f"✓ Pareto Optimal SOPs: {len(pareto_front)}")
204
-
205
  # Find best average score
206
  best_sop = max(all_entries, key=lambda e: e['evaluation'].average_score())
207
  baseline_avg = baseline_eval.average_score()
208
  best_avg = best_sop['evaluation'].average_score()
209
  improvement = ((best_avg - baseline_avg) / baseline_avg) * 100
210
-
211
  print(f"\nBest SOP: v{best_sop['version']} - {best_sop['description']}")
212
  print(f" Average Score: {best_avg:.3f} ({improvement:+.1f}% vs baseline)")
213
-
214
  print("\n✓ Visualization saved to: data/pareto_frontier_analysis.png")
215
  print("\n" + "=" * 80)
216
 
 
10
  project_root = Path(__file__).parent.parent
11
  sys.path.insert(0, str(project_root))
12
 
13
+ from datetime import datetime
14
+ from typing import Any
15
+
16
  from src.config import BASELINE_SOP
 
17
  from src.evaluation.evaluators import run_full_evaluation
18
  from src.evolution.director import SOPGenePool, run_evolution_cycle
19
  from src.evolution.pareto import (
20
+ analyze_improvements,
21
  identify_pareto_front,
 
22
  print_pareto_summary,
23
+ visualize_pareto_frontier,
24
  )
25
+ from src.state import GuildState, PatientInput
26
+ from src.workflow import create_guild
27
 
28
 
29
  def create_test_patient() -> PatientInput:
 
53
  "Chloride": 102.0,
54
  "Bicarbonate": 24.0
55
  }
56
+
57
+ model_prediction: dict[str, Any] = {
58
  'disease': 'Type 2 Diabetes',
59
  'confidence': 0.92,
60
  'probabilities': {
 
64
  },
65
  'prediction_timestamp': '2025-01-01T10:00:00'
66
  }
67
+
68
  patient_context = {
69
  'patient_id': 'TEST-001',
70
  'age': 55,
 
74
  'current_medications': ["Metformin 500mg"],
75
  'query': "My blood sugar has been high lately. What should I do?"
76
  }
77
+
78
  return PatientInput(
79
  biomarkers=biomarkers,
80
  model_prediction=model_prediction,
 
87
  print("\n" + "=" * 80)
88
  print("PHASE 3: SELF-IMPROVEMENT LOOP TEST")
89
  print("=" * 80)
90
+
91
  # Setup
92
  print("\n1. Initializing system...")
93
  guild = create_guild()
94
  patient = create_test_patient()
95
+
96
  # Initialize gene pool with baseline
97
  print("\n2. Creating SOP Gene Pool...")
98
  gene_pool = SOPGenePool()
99
+
100
  print("\n3. Evaluating Baseline SOP...")
101
  # Run workflow with baseline SOP
102
+
103
  initial_state: GuildState = {
104
  'patient_biomarkers': patient.biomarkers,
105
  'model_prediction': patient.model_prediction,
 
113
  'processing_timestamp': datetime.now().isoformat(),
114
  'sop_version': "Baseline"
115
  }
116
+
117
  guild_state = guild.workflow.invoke(initial_state)
118
+
119
  baseline_response = guild_state['final_response']
120
  agent_outputs = guild_state['agent_outputs']
121
+
122
  baseline_eval = run_full_evaluation(
123
  final_response=baseline_response,
124
  agent_outputs=agent_outputs,
125
  biomarkers=patient.biomarkers
126
  )
127
+
128
  gene_pool.add(
129
  sop=BASELINE_SOP,
130
  evaluation=baseline_eval,
131
  parent_version=None,
132
  description="Baseline SOP"
133
  )
134
+
135
  print(f"\n✓ Baseline Average Score: {baseline_eval.average_score():.3f}")
136
  print(f" Clinical Accuracy: {baseline_eval.clinical_accuracy.score:.3f}")
137
  print(f" Evidence Grounding: {baseline_eval.evidence_grounding.score:.3f}")
138
  print(f" Actionability: {baseline_eval.actionability.score:.3f}")
139
  print(f" Clarity: {baseline_eval.clarity.score:.3f}")
140
  print(f" Safety & Completeness: {baseline_eval.safety_completeness.score:.3f}")
141
+
142
  # Run evolution cycles
143
  num_cycles = 2
144
  print(f"\n4. Running {num_cycles} Evolution Cycles...")
145
+
146
  for cycle in range(1, num_cycles + 1):
147
  print(f"\n{'─' * 80}")
148
  print(f"EVOLUTION CYCLE {cycle}")
149
  print(f"{'─' * 80}")
150
+
151
  try:
152
  # Create evaluation function for this cycle
153
  def eval_func(final_response, agent_outputs, biomarkers):
 
156
  agent_outputs=agent_outputs,
157
  biomarkers=biomarkers
158
  )
159
+
160
  new_entries = run_evolution_cycle(
161
  gene_pool=gene_pool,
162
  patient_input=patient,
163
  workflow_graph=guild.workflow,
164
  evaluation_func=eval_func
165
  )
166
+
167
  print(f"\n✓ Cycle {cycle} complete: Added {len(new_entries)} new SOPs to gene pool")
168
+
169
  for entry in new_entries:
170
  print(f"\n SOP v{entry['version']}: {entry['description']}")
171
  print(f" Average Score: {entry['evaluation'].average_score():.3f}")
172
+
173
  except Exception as e:
174
  print(f"\n⚠️ Cycle {cycle} encountered error: {e}")
175
  print("Continuing to next cycle...")
176
+
177
  # Show gene pool summary
178
  print("\n5. Gene Pool Summary:")
179
  gene_pool.summary()
180
+
181
  # Pareto Analysis
182
  print("\n6. Identifying Pareto Frontier...")
183
  all_entries = gene_pool.gene_pool
184
  pareto_front = identify_pareto_front(all_entries)
185
+
186
  print(f"\n✓ Pareto frontier contains {len(pareto_front)} non-dominated solutions")
187
  print_pareto_summary(pareto_front)
188
+
189
  # Improvement Analysis
190
  print("\n7. Analyzing Improvements...")
191
  analyze_improvements(all_entries)
192
+
193
  # Visualizations
194
  print("\n8. Generating Visualizations...")
195
  visualize_pareto_frontier(pareto_front)
196
+
197
  # Final Summary
198
  print("\n" + "=" * 80)
199
  print("EVOLUTION TEST COMPLETE")
200
  print("=" * 80)
201
+
202
  print(f"\n✓ Total SOPs in Gene Pool: {len(all_entries)}")
203
  print(f"✓ Pareto Optimal SOPs: {len(pareto_front)}")
204
+
205
  # Find best average score
206
  best_sop = max(all_entries, key=lambda e: e['evaluation'].average_score())
207
  baseline_avg = baseline_eval.average_score()
208
  best_avg = best_sop['evaluation'].average_score()
209
  improvement = ((best_avg - baseline_avg) / baseline_avg) * 100
210
+
211
  print(f"\nBest SOP: v{best_sop['version']} - {best_sop['description']}")
212
  print(f" Average Score: {best_avg:.3f} ({improvement:+.1f}% vs baseline)")
213
+
214
  print("\n✓ Visualization saved to: data/pareto_frontier_analysis.png")
215
  print("\n" + "=" * 80)
216
 
{tests → archive/tests}/test_evolution_quick.py RENAMED
@@ -5,6 +5,7 @@ Tests gene pool, diagnostician, and architect without full workflow
5
 
6
  import sys
7
  from pathlib import Path
 
8
  sys.path.insert(0, str(Path(__file__).parent.parent))
9
 
10
  from src.config import BASELINE_SOP
@@ -17,11 +18,11 @@ def main():
17
  print("\n" + "=" * 80)
18
  print("QUICK PHASE 3 TEST")
19
  print("=" * 80)
20
-
21
  # Test 1: Gene Pool
22
  print("\n1. Testing Gene Pool...")
23
  gene_pool = SOPGenePool()
24
-
25
  # Create mock evaluation (baseline with low clarity)
26
  baseline_eval = EvaluationResult(
27
  clinical_accuracy=GradedScore(score=0.95, reasoning="Accurate"),
@@ -30,48 +31,48 @@ def main():
30
  clarity=GradedScore(score=0.75, reasoning="Could be clearer"),
31
  safety_completeness=GradedScore(score=1.0, reasoning="Complete")
32
  )
33
-
34
  gene_pool.add(
35
  sop=BASELINE_SOP,
36
  evaluation=baseline_eval,
37
  parent_version=None,
38
  description="Baseline SOP"
39
  )
40
-
41
- print(f"✓ Gene pool initialized with 1 SOP")
42
  print(f" Average score: {baseline_eval.average_score():.3f}")
43
-
44
  # Test 2: Performance Diagnostician
45
  print("\n2. Testing Performance Diagnostician...")
46
  diagnosis = performance_diagnostician(baseline_eval)
47
-
48
- print(f"✓ Diagnosis complete")
49
  print(f" Primary weakness: {diagnosis.primary_weakness}")
50
  print(f" Root cause: {diagnosis.root_cause_analysis[:100]}...")
51
  print(f" Recommendation: {diagnosis.recommendation[:100]}...")
52
-
53
  # Test 3: SOP Architect
54
  print("\n3. Testing SOP Architect...")
55
  evolved_sops = sop_architect(diagnosis, BASELINE_SOP)
56
-
57
  print(f"\n✓ Generated {len(evolved_sops.mutations)} mutations")
58
  for i, mutation in enumerate(evolved_sops.mutations, 1):
59
  print(f"\n Mutation {i}: {mutation.description}")
60
  print(f" Disease explainer K: {mutation.disease_explainer_k}")
61
  print(f" Detail level: {mutation.explainer_detail_level}")
62
  print(f" Citations required: {mutation.require_pdf_citations}")
63
-
64
  # Test 4: Gene Pool Summary
65
  print("\n4. Gene Pool Summary:")
66
  gene_pool.summary()
67
-
68
  # Test 5: Average score method
69
  print("\n5. Testing average_score method...")
70
  avg = baseline_eval.average_score()
71
  print(f"✓ Average score calculation: {avg:.3f}")
72
  vector = baseline_eval.to_vector()
73
  print(f"✓ Score vector: {[f'{s:.2f}' for s in vector]}")
74
-
75
  print("\n" + "=" * 80)
76
  print("QUICK TEST COMPLETE")
77
  print("=" * 80)
 
5
 
6
  import sys
7
  from pathlib import Path
8
+
9
  sys.path.insert(0, str(Path(__file__).parent.parent))
10
 
11
  from src.config import BASELINE_SOP
 
18
  print("\n" + "=" * 80)
19
  print("QUICK PHASE 3 TEST")
20
  print("=" * 80)
21
+
22
  # Test 1: Gene Pool
23
  print("\n1. Testing Gene Pool...")
24
  gene_pool = SOPGenePool()
25
+
26
  # Create mock evaluation (baseline with low clarity)
27
  baseline_eval = EvaluationResult(
28
  clinical_accuracy=GradedScore(score=0.95, reasoning="Accurate"),
 
31
  clarity=GradedScore(score=0.75, reasoning="Could be clearer"),
32
  safety_completeness=GradedScore(score=1.0, reasoning="Complete")
33
  )
34
+
35
  gene_pool.add(
36
  sop=BASELINE_SOP,
37
  evaluation=baseline_eval,
38
  parent_version=None,
39
  description="Baseline SOP"
40
  )
41
+
42
+ print("✓ Gene pool initialized with 1 SOP")
43
  print(f" Average score: {baseline_eval.average_score():.3f}")
44
+
45
  # Test 2: Performance Diagnostician
46
  print("\n2. Testing Performance Diagnostician...")
47
  diagnosis = performance_diagnostician(baseline_eval)
48
+
49
+ print("✓ Diagnosis complete")
50
  print(f" Primary weakness: {diagnosis.primary_weakness}")
51
  print(f" Root cause: {diagnosis.root_cause_analysis[:100]}...")
52
  print(f" Recommendation: {diagnosis.recommendation[:100]}...")
53
+
54
  # Test 3: SOP Architect
55
  print("\n3. Testing SOP Architect...")
56
  evolved_sops = sop_architect(diagnosis, BASELINE_SOP)
57
+
58
  print(f"\n✓ Generated {len(evolved_sops.mutations)} mutations")
59
  for i, mutation in enumerate(evolved_sops.mutations, 1):
60
  print(f"\n Mutation {i}: {mutation.description}")
61
  print(f" Disease explainer K: {mutation.disease_explainer_k}")
62
  print(f" Detail level: {mutation.explainer_detail_level}")
63
  print(f" Citations required: {mutation.require_pdf_citations}")
64
+
65
  # Test 4: Gene Pool Summary
66
  print("\n4. Gene Pool Summary:")
67
  gene_pool.summary()
68
+
69
  # Test 5: Average score method
70
  print("\n5. Testing average_score method...")
71
  avg = baseline_eval.average_score()
72
  print(f"✓ Average score calculation: {avg:.3f}")
73
  vector = baseline_eval.to_vector()
74
  print(f"✓ Score vector: {[f'{s:.2f}' for s in vector]}")
75
+
76
  print("\n" + "=" * 80)
77
  print("QUICK TEST COMPLETE")
78
  print("=" * 80)
docker-compose.yml CHANGED
@@ -143,6 +143,26 @@ services:
143
  # count: 1
144
  # capabilities: [gpu]
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # -----------------------------------------------------------------------
147
  # Observability
148
  # -----------------------------------------------------------------------
 
143
  # count: 1
144
  # capabilities: [gpu]
145
 
146
+ airflow:
147
+ image: apache/airflow:2.8.2
148
+ container_name: mediguard-airflow
149
+ environment:
150
+ - AIRFLOW__CORE__LOAD_EXAMPLES=false
151
+ - AIRFLOW__CORE__EXECUTOR=LocalExecutor
152
+ - AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://${POSTGRES__USER:-mediguard}:${POSTGRES__PASSWORD:-mediguard_secret}@postgres:5432/${POSTGRES__DATABASE:-mediguard}
153
+ command: standalone
154
+ ports:
155
+ - "${AIRFLOW_PORT:-8080}:8080"
156
+ volumes:
157
+ - ./airflow/dags:/opt/airflow/dags:ro
158
+ - ./data/medical_pdfs:/app/data/medical_pdfs:ro
159
+ - .:/app:ro
160
+ working_dir: /app
161
+ depends_on:
162
+ postgres:
163
+ condition: service_healthy
164
+ restart: unless-stopped
165
+
166
  # -----------------------------------------------------------------------
167
  # Observability
168
  # -----------------------------------------------------------------------
gradio_launcher.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MediGuard AI — Gradio Launcher wrapper.
3
+
4
+ Spawns the Gradio frontend UI on the correct designated port (7861), separating
5
+ the frontend runner from the production API layer entirely.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import sys
11
+
12
+ # Ensure project root is in path
13
+ from pathlib import Path
14
+
15
+ sys.path.insert(0, str(Path(__file__).parent))
16
+
17
+ from src.gradio_app import launch_gradio
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+
21
+ if __name__ == "__main__":
22
+ port = int(os.environ.get("GRADIO_PORT", 7861))
23
+ logging.info("Starting Gradio Web UI Launcher on port %d...", port)
24
+ launch_gradio(share=False, server_port=port)
huggingface/app.py CHANGED
@@ -37,7 +37,7 @@ import sys
37
  import time
38
  import traceback
39
  from pathlib import Path
40
- from typing import Any, Optional
41
 
42
  # Ensure project root is in path
43
  _project_root = str(Path(__file__).parent.parent)
@@ -114,7 +114,7 @@ def setup_llm_provider():
114
  """
115
  groq_key, google_key = get_api_keys()
116
  provider = None
117
-
118
  if groq_key:
119
  os.environ["LLM_PROVIDER"] = "groq"
120
  os.environ["GROQ_API_KEY"] = groq_key
@@ -127,18 +127,18 @@ def setup_llm_provider():
127
  os.environ["GEMINI_MODEL"] = get_gemini_model()
128
  provider = "gemini"
129
  logger.info(f"Configured Gemini provider with model: {get_gemini_model()}")
130
-
131
  # Set up embedding provider
132
  embedding_provider = get_embedding_provider()
133
  os.environ["EMBEDDING_PROVIDER"] = embedding_provider
134
-
135
  # If Jina is configured, set the API key
136
  jina_key = get_jina_api_key()
137
  if jina_key:
138
  os.environ["JINA_API_KEY"] = jina_key
139
  os.environ["EMBEDDING__JINA_API_KEY"] = jina_key
140
  logger.info("Jina embeddings configured")
141
-
142
  # Set up Langfuse if enabled
143
  if is_langfuse_enabled():
144
  os.environ["LANGFUSE__ENABLED"] = "true"
@@ -147,7 +147,7 @@ def setup_llm_provider():
147
  if val:
148
  os.environ[var] = val
149
  logger.info("Langfuse observability enabled")
150
-
151
  return provider
152
 
153
 
@@ -192,21 +192,21 @@ def reset_guild():
192
  def get_guild():
193
  """Lazy initialization of the Clinical Insight Guild."""
194
  global _guild, _guild_error, _guild_provider
195
-
196
  # Check if we need to reinitialize (provider changed)
197
  current_provider = os.getenv("LLM_PROVIDER")
198
  if _guild_provider and _guild_provider != current_provider:
199
  logger.info(f"Provider changed from {_guild_provider} to {current_provider}, reinitializing...")
200
  reset_guild()
201
-
202
  if _guild is not None:
203
  return _guild
204
-
205
  if _guild_error is not None:
206
  # Don't cache errors forever - allow retry
207
  logger.warning("Previous initialization failed, retrying...")
208
  _guild_error = None
209
-
210
  try:
211
  logger.info("Initializing Clinical Insight Guild...")
212
  logger.info(f" LLM_PROVIDER: {os.getenv('LLM_PROVIDER', 'not set')}")
@@ -214,17 +214,17 @@ def get_guild():
214
  logger.info(f" GOOGLE_API_KEY: {'✓ set' if os.getenv('GOOGLE_API_KEY') else '✗ not set'}")
215
  logger.info(f" EMBEDDING_PROVIDER: {os.getenv('EMBEDDING_PROVIDER', 'huggingface')}")
216
  logger.info(f" JINA_API_KEY: {'✓ set' if os.getenv('JINA_API_KEY') else '✗ not set'}")
217
-
218
  start = time.time()
219
-
220
  from src.workflow import create_guild
221
  _guild = create_guild()
222
  _guild_provider = current_provider
223
-
224
  elapsed = time.time() - start
225
  logger.info(f"Guild initialized in {elapsed:.1f}s")
226
  return _guild
227
-
228
  except Exception as exc:
229
  logger.error(f"Failed to initialize guild: {exc}")
230
  _guild_error = exc
@@ -237,11 +237,8 @@ def get_guild():
237
 
238
  # Import shared parsing and prediction logic
239
  from src.shared_utils import (
240
- parse_biomarkers,
241
  get_primary_prediction,
242
- flag_biomarkers,
243
- severity_to_emoji,
244
- format_confidence_percent,
245
  )
246
 
247
 
@@ -267,10 +264,10 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
267
  <p style="margin: 8px 0 0 0; color: #64748b;">Please enter biomarkers to analyze.</p>
268
  </div>
269
  """
270
-
271
  # Check API key dynamically (HF injects secrets after startup)
272
  groq_key, google_key = get_api_keys()
273
-
274
  if not groq_key and not google_key:
275
  return "", "", """
276
  <div style="background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%); border: 1px solid #ef4444; border-radius: 10px; padding: 16px;">
@@ -297,15 +294,15 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
297
  </details>
298
  </div>
299
  """
300
-
301
  # Setup provider based on available key
302
  provider = setup_llm_provider()
303
  logger.info(f"Using LLM provider: {provider}")
304
-
305
  try:
306
  progress(0.1, desc="📝 Parsing biomarkers...")
307
  biomarkers = parse_biomarkers(input_text)
308
-
309
  if not biomarkers:
310
  return "", "", """
311
  <div style="background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); border: 1px solid #fbbf24; border-radius: 10px; padding: 16px;">
@@ -317,42 +314,42 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
317
  </ul>
318
  </div>
319
  """
320
-
321
  progress(0.2, desc="🔧 Initializing AI agents...")
322
-
323
  # Initialize guild
324
  guild = get_guild()
325
-
326
  # Prepare input
327
  from src.state import PatientInput
328
-
329
  # Auto-generate prediction based on common patterns
330
  prediction = auto_predict(biomarkers)
331
-
332
  patient_input = PatientInput(
333
  biomarkers=biomarkers,
334
  model_prediction=prediction,
335
  patient_context={"patient_id": "HF_User", "source": "huggingface_spaces"}
336
  )
337
-
338
  progress(0.4, desc="🤖 Running Clinical Insight Guild...")
339
-
340
  # Run analysis
341
  start = time.time()
342
  result = guild.run(patient_input)
343
  elapsed = time.time() - start
344
-
345
  progress(0.9, desc="✨ Formatting results...")
346
-
347
  # Extract response
348
  final_response = result.get("final_response", {})
349
-
350
  # Format summary
351
  summary = format_summary(final_response, elapsed)
352
-
353
  # Format details
354
  details = json.dumps(final_response, indent=2, default=str)
355
-
356
  status = f"""
357
  <div style="background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%); border: 1px solid #10b981; border-radius: 10px; padding: 12px; display: flex; align-items: center; gap: 10px;">
358
  <span style="font-size: 1.5em;">✅</span>
@@ -362,9 +359,9 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
362
  </div>
363
  </div>
364
  """
365
-
366
  return summary, details, status
367
-
368
  except Exception as exc:
369
  logger.error(f"Analysis error: {exc}", exc_info=True)
370
  error_msg = f"""
@@ -384,14 +381,14 @@ def format_summary(response: dict, elapsed: float) -> str:
384
  """Format the analysis response as clean markdown with black text."""
385
  if not response:
386
  return "❌ **No analysis results available.**"
387
-
388
  parts = []
389
-
390
  # Header with primary finding and confidence
391
  primary = response.get("primary_finding", "Analysis Complete")
392
  confidence = response.get("confidence", {})
393
  conf_score = confidence.get("overall_score", 0) if isinstance(confidence, dict) else 0
394
-
395
  # Determine severity
396
  severity = response.get("severity", "low")
397
  severity_config = {
@@ -401,14 +398,14 @@ def format_summary(response: dict, elapsed: float) -> str:
401
  "low": ("🟢", "#16a34a", "#f0fdf4")
402
  }
403
  emoji, color, bg_color = severity_config.get(severity, severity_config["low"])
404
-
405
  # Build confidence display
406
  conf_badge = ""
407
  if conf_score:
408
  conf_pct = int(conf_score * 100)
409
  conf_color = "#16a34a" if conf_pct >= 80 else "#ca8a04" if conf_pct >= 60 else "#dc2626"
410
  conf_badge = f'<span style="background: {conf_color}; color: white; padding: 4px 12px; border-radius: 20px; font-size: 0.85em; margin-left: 12px;">{conf_pct}% confidence</span>'
411
-
412
  parts.append(f"""
413
  <div style="background: linear-gradient(135deg, {bg_color} 0%, white 100%); border-left: 4px solid {color}; border-radius: 12px; padding: 20px; margin-bottom: 20px;">
414
  <div style="display: flex; align-items: center; flex-wrap: wrap;">
@@ -417,7 +414,7 @@ def format_summary(response: dict, elapsed: float) -> str:
417
  {conf_badge}
418
  </div>
419
  </div>""")
420
-
421
  # Critical Alerts
422
  alerts = response.get("safety_alerts", [])
423
  if alerts:
@@ -427,7 +424,7 @@ def format_summary(response: dict, elapsed: float) -> str:
427
  alert_items += f'<li><strong>{alert.get("alert_type", "Alert")}:</strong> {alert.get("message", "")}</li>'
428
  else:
429
  alert_items += f'<li>{alert}</li>'
430
-
431
  parts.append(f"""
432
  <div style="background: linear-gradient(135deg, #fef2f2 0%, #fee2e2 100%); border: 1px solid #fecaca; border-radius: 12px; padding: 16px; margin-bottom: 16px;">
433
  <h4 style="margin: 0 0 12px 0; color: #dc2626; display: flex; align-items: center; gap: 8px;">
@@ -436,7 +433,7 @@ def format_summary(response: dict, elapsed: float) -> str:
436
  <ul style="margin: 0; padding-left: 20px; color: #991b1b;">{alert_items}</ul>
437
  </div>
438
  """)
439
-
440
  # Key Findings
441
  findings = response.get("key_findings", [])
442
  if findings:
@@ -447,7 +444,7 @@ def format_summary(response: dict, elapsed: float) -> str:
447
  <ul style="margin: 0; padding-left: 20px; color: #475569;">{finding_items}</ul>
448
  </div>
449
  """)
450
-
451
  # Biomarker Flags - as a visual grid
452
  flags = response.get("biomarker_flags", [])
453
  if flags and len(flags) > 0:
@@ -460,7 +457,7 @@ def format_summary(response: dict, elapsed: float) -> str:
460
  continue
461
  status = flag.get("status", "normal").lower()
462
  value = flag.get("value", flag.get("result", "N/A"))
463
-
464
  status_styles = {
465
  "critical": ("🔴", "#dc2626", "#fef2f2"),
466
  "high": ("🔴", "#dc2626", "#fef2f2"),
@@ -469,7 +466,7 @@ def format_summary(response: dict, elapsed: float) -> str:
469
  "normal": ("🟢", "#16a34a", "#f0fdf4")
470
  }
471
  s_emoji, s_color, s_bg = status_styles.get(status, status_styles["normal"])
472
-
473
  flag_cards += f"""
474
  <div style="background: {s_bg}; border: 1px solid {s_color}33; border-radius: 8px; padding: 12px; text-align: center;">
475
  <div style="font-size: 1.2em;">{s_emoji}</div>
@@ -478,7 +475,7 @@ def format_summary(response: dict, elapsed: float) -> str:
478
  <div style="font-size: 0.75em; color: #64748b; text-transform: capitalize;">{status}</div>
479
  </div>
480
  """
481
-
482
  if flag_cards: # Only show section if we have cards
483
  parts.append(f"""
484
  <div style="margin-bottom: 16px;">
@@ -488,11 +485,11 @@ def format_summary(response: dict, elapsed: float) -> str:
488
  </div>
489
  </div>
490
  """)
491
-
492
  # Recommendations - organized sections
493
  recs = response.get("recommendations", {})
494
  rec_sections = ""
495
-
496
  immediate = recs.get("immediate_actions", []) if isinstance(recs, dict) else []
497
  if immediate and len(immediate) > 0:
498
  items = "".join([f'<li style="margin-bottom: 6px;">{str(a).strip()}</li>' for a in immediate[:3]])
@@ -502,7 +499,7 @@ def format_summary(response: dict, elapsed: float) -> str:
502
  <ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
503
  </div>
504
  """
505
-
506
  lifestyle = recs.get("lifestyle_modifications", []) if isinstance(recs, dict) else []
507
  if lifestyle and len(lifestyle) > 0:
508
  items = "".join([f'<li style="margin-bottom: 6px;">{str(m).strip()}</li>' for m in lifestyle[:3]])
@@ -512,7 +509,7 @@ def format_summary(response: dict, elapsed: float) -> str:
512
  <ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
513
  </div>
514
  """
515
-
516
  followup = recs.get("follow_up", []) if isinstance(recs, dict) else []
517
  if followup and len(followup) > 0:
518
  items = "".join([f'<li style="margin-bottom: 6px;">{str(f).strip()}</li>' for f in followup[:3]])
@@ -522,10 +519,10 @@ def format_summary(response: dict, elapsed: float) -> str:
522
  <ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
523
  </div>
524
  """
525
-
526
  # Add default recommendations if none provided
527
  if not rec_sections:
528
- rec_sections = f"""
529
  <div style="margin-bottom: 12px;">
530
  <h5 style="margin: 0 0 8px 0; color: #2563eb;">📋 General Recommendations</h5>
531
  <ul style="margin: 0; padding-left: 20px; color: #475569;">
@@ -535,7 +532,7 @@ def format_summary(response: dict, elapsed: float) -> str:
535
  </ul>
536
  </div>
537
  """
538
-
539
  if rec_sections:
540
  parts.append(f"""
541
  <div style="background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%); border-radius: 12px; padding: 16px; margin-bottom: 16px;">
@@ -543,7 +540,7 @@ def format_summary(response: dict, elapsed: float) -> str:
543
  {rec_sections}
544
  </div>
545
  """)
546
-
547
  # Disease Explanation
548
  explanation = response.get("disease_explanation", {})
549
  if explanation and isinstance(explanation, dict):
@@ -555,7 +552,7 @@ def format_summary(response: dict, elapsed: float) -> str:
555
  <p style="margin: 0; color: #475569; line-height: 1.6;">{pathophys[:600]}{'...' if len(pathophys) > 600 else ''}</p>
556
  </div>
557
  """)
558
-
559
  # Conversational Summary
560
  conv_summary = response.get("conversational_summary", "")
561
  if conv_summary:
@@ -565,7 +562,7 @@ def format_summary(response: dict, elapsed: float) -> str:
565
  <p style="margin: 0; color: #475569; line-height: 1.6;">{conv_summary[:1000]}</p>
566
  </div>
567
  """)
568
-
569
  # Footer
570
  parts.append(f"""
571
  <div style="border-top: 1px solid #e2e8f0; padding-top: 16px; margin-top: 8px; text-align: center;">
@@ -577,7 +574,7 @@ def format_summary(response: dict, elapsed: float) -> str:
577
  </p>
578
  </div>
579
  """)
580
-
581
  return "\n".join(parts)
582
 
583
 
@@ -606,10 +603,10 @@ def _get_rag_service():
606
  _rag_service_error = None
607
 
608
  try:
 
609
  from src.services.agents.agentic_rag import AgenticRAGService
610
  from src.services.agents.context import AgenticContext
611
  from src.services.retrieval.factory import make_retriever
612
- from src.llm_config import get_synthesizer
613
 
614
  llm = get_synthesizer()
615
  retriever = make_retriever() # auto-detects FAISS
@@ -637,8 +634,8 @@ def _get_rag_service():
637
 
638
  def _fallback_qa(question: str, context_text: str = "") -> str:
639
  """Direct retriever+LLM fallback when agentic pipeline is unavailable."""
640
- from src.services.retrieval.factory import make_retriever
641
  from src.llm_config import get_synthesizer
 
642
 
643
  retriever = make_retriever()
644
  search_query = f"{context_text} {question}" if context_text.strip() else question
@@ -727,41 +724,53 @@ def answer_medical_question(
727
 
728
  except Exception as exc:
729
  logger.exception(f"Q&A error: {exc}")
730
- error_msg = f"❌ Error: {str(exc)}"
731
  history = (chat_history or []) + [(question, error_msg)]
732
  return error_msg, history
733
 
734
 
735
- def streaming_answer(question: str, context: str = ""):
736
  """Stream answer using the full agentic RAG pipeline.
737
  Falls back to direct retriever+LLM if the pipeline is unavailable.
738
  """
 
739
  if not question.strip():
740
- yield ""
741
  return
742
 
743
- groq_key, google_key = get_api_keys()
 
744
  if not groq_key and not google_key:
745
- yield "❌ Please add your GROQ_API_KEY or GOOGLE_API_KEY in Space Settings → Secrets."
 
746
  return
747
 
 
 
 
 
 
 
748
  setup_llm_provider()
749
 
750
  try:
751
- yield "🛡️ Checking medical domain relevance...\n\n"
 
752
 
753
  start_time = time.time()
754
 
755
  rag_service = _get_rag_service()
756
  if rag_service is not None:
757
- yield "🛡️ Checking medical domain relevance...\n🔍 Retrieving medical documents...\n\n"
 
758
  result = rag_service.ask(query=question, patient_context=context)
759
  answer = result.get("final_answer", "")
760
  guardrail = result.get("guardrail_score")
761
  docs_relevant = len(result.get("relevant_documents", []))
762
  docs_retrieved = len(result.get("retrieved_documents", []))
763
  else:
764
- yield "🔍 Searching medical knowledge base...\n📚 Retrieving relevant documents...\n\n"
 
765
  answer = _fallback_qa(question, context)
766
  guardrail = None
767
  docs_relevant = 0
@@ -770,7 +779,8 @@ def streaming_answer(question: str, context: str = ""):
770
  if not answer:
771
  answer = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
772
 
773
- yield "🛡️ Guardrail ✓\n🔍 Retrieved ✓\n📊 Graded ✓\n💭 Generating response...\n\n"
 
774
 
775
  elapsed = time.time() - start_time
776
 
@@ -779,9 +789,10 @@ def streaming_answer(question: str, context: str = ""):
779
  accumulated = ""
780
  for i, word in enumerate(words):
781
  accumulated += word + " "
782
- if i % 5 == 0:
783
- yield accumulated
784
- time.sleep(0.02)
 
785
 
786
  # Final response with metadata
787
  meta_parts = [f"⏱️ {elapsed:.1f}s"]
@@ -792,15 +803,34 @@ def streaming_answer(question: str, context: str = ""):
792
  meta_parts.append("🤖 Agentic RAG" if rag_service else "🤖 RAG")
793
  meta_line = " | ".join(meta_parts)
794
 
795
- yield f"""{answer}
796
-
797
- ---
798
- *{meta_line}*
799
- """
800
 
801
  except Exception as exc:
802
  logger.exception(f"Streaming Q&A error: {exc}")
803
- yield f"❌ Error: {str(exc)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
 
805
 
806
  # ---------------------------------------------------------------------------
@@ -1039,7 +1069,7 @@ footer { display: none !important; }
1039
 
1040
  def create_demo() -> gr.Blocks:
1041
  """Create the Gradio Blocks interface with modern medical UI."""
1042
-
1043
  with gr.Blocks(
1044
  title="Agentic RagBot - Medical Biomarker Analysis",
1045
  theme=gr.themes.Soft(
@@ -1065,7 +1095,7 @@ def create_demo() -> gr.Blocks:
1065
  ),
1066
  css=CUSTOM_CSS,
1067
  ) as demo:
1068
-
1069
  # ===== HEADER =====
1070
  gr.HTML("""
1071
  <div class="header-container">
@@ -1079,7 +1109,7 @@ def create_demo() -> gr.Blocks:
1079
  </div>
1080
  </div>
1081
  """)
1082
-
1083
  # ===== API KEY INFO =====
1084
  gr.HTML("""
1085
  <div class="info-banner">
@@ -1096,20 +1126,20 @@ def create_demo() -> gr.Blocks:
1096
  </div>
1097
  </div>
1098
  """)
1099
-
1100
  # ===== MAIN TABS =====
1101
  with gr.Tabs() as main_tabs:
1102
-
1103
  # ==================== TAB 1: BIOMARKER ANALYSIS ====================
1104
  with gr.Tab("🔬 Biomarker Analysis", id="biomarker-tab"):
1105
-
1106
  # ===== MAIN CONTENT =====
1107
  with gr.Row(equal_height=False):
1108
-
1109
  # ----- LEFT PANEL: INPUT -----
1110
  with gr.Column(scale=2, min_width=400):
1111
  gr.HTML('<div class="section-title">📝 Enter Your Biomarkers</div>')
1112
-
1113
  with gr.Group():
1114
  input_text = gr.Textbox(
1115
  label="",
@@ -1118,31 +1148,31 @@ def create_demo() -> gr.Blocks:
1118
  max_lines=12,
1119
  show_label=False,
1120
  )
1121
-
1122
  with gr.Row():
1123
  analyze_btn = gr.Button(
1124
- "🔬 Analyze Biomarkers",
1125
- variant="primary",
1126
  size="lg",
1127
  scale=3,
1128
  )
1129
  clear_btn = gr.Button(
1130
- "🗑️ Clear",
1131
  variant="secondary",
1132
  size="lg",
1133
  scale=1,
1134
  )
1135
-
1136
  # Status display
1137
  status_output = gr.Markdown(
1138
  value="",
1139
  elem_classes="status-box"
1140
  )
1141
-
1142
  # Quick Examples
1143
  gr.HTML('<div class="section-title" style="margin-top: 24px;">⚡ Quick Examples</div>')
1144
  gr.HTML('<p style="color: #64748b; font-size: 0.9em; margin-bottom: 12px;">Click any example to load it instantly</p>')
1145
-
1146
  examples = gr.Examples(
1147
  examples=[
1148
  ["Glucose: 185, HbA1c: 8.2, Cholesterol: 245, LDL: 165"],
@@ -1154,7 +1184,7 @@ def create_demo() -> gr.Blocks:
1154
  inputs=input_text,
1155
  label="",
1156
  )
1157
-
1158
  # Supported Biomarkers
1159
  with gr.Accordion("📊 Supported Biomarkers", open=False):
1160
  gr.HTML("""
@@ -1185,11 +1215,11 @@ def create_demo() -> gr.Blocks:
1185
  </div>
1186
  </div>
1187
  """)
1188
-
1189
  # ----- RIGHT PANEL: RESULTS -----
1190
  with gr.Column(scale=3, min_width=500):
1191
  gr.HTML('<div class="section-title">📊 Analysis Results</div>')
1192
-
1193
  with gr.Tabs() as result_tabs:
1194
  with gr.Tab("📋 Summary", id="summary"):
1195
  summary_output = gr.Markdown(
@@ -1202,7 +1232,7 @@ def create_demo() -> gr.Blocks:
1202
  """,
1203
  elem_classes="summary-output"
1204
  )
1205
-
1206
  with gr.Tab("🔍 Detailed JSON", id="json"):
1207
  details_output = gr.Code(
1208
  label="",
@@ -1210,10 +1240,10 @@ def create_demo() -> gr.Blocks:
1210
  lines=30,
1211
  show_label=False,
1212
  )
1213
-
1214
  # ==================== TAB 2: MEDICAL Q&A ====================
1215
  with gr.Tab("💬 Medical Q&A", id="qa-tab"):
1216
-
1217
  gr.HTML("""
1218
  <div style="margin-bottom: 20px;">
1219
  <h3 style="color: #1e3a5f; margin: 0 0 8px 0;">💬 Medical Q&A Assistant</h3>
@@ -1222,7 +1252,7 @@ def create_demo() -> gr.Blocks:
1222
  </p>
1223
  </div>
1224
  """)
1225
-
1226
  with gr.Row(equal_height=False):
1227
  with gr.Column(scale=1):
1228
  qa_context = gr.Textbox(
@@ -1231,6 +1261,11 @@ def create_demo() -> gr.Blocks:
1231
  lines=3,
1232
  max_lines=6,
1233
  )
 
 
 
 
 
1234
  qa_question = gr.Textbox(
1235
  label="Your Question",
1236
  placeholder="Ask any medical question...\n• What do my elevated glucose levels indicate?\n• Should I be concerned about my HbA1c of 7.5%?\n• What lifestyle changes help with prediabetes?",
@@ -1246,11 +1281,11 @@ def create_demo() -> gr.Blocks:
1246
  )
1247
  qa_clear_btn = gr.Button(
1248
  "🗑️ Clear",
1249
- variant="secondary",
1250
  size="lg",
1251
  scale=1,
1252
  )
1253
-
1254
  # Quick question examples
1255
  gr.HTML('<h4 style="margin-top: 16px; color: #1e3a5f;">Example Questions</h4>')
1256
  qa_examples = gr.Examples(
@@ -1263,42 +1298,54 @@ def create_demo() -> gr.Blocks:
1263
  inputs=[qa_question, qa_context],
1264
  label="",
1265
  )
1266
-
1267
  with gr.Column(scale=2):
1268
  gr.HTML('<h4 style="color: #1e3a5f; margin-bottom: 12px;">📝 Answer</h4>')
1269
- qa_answer = gr.Markdown(
1270
- value="""
1271
- <div style="text-align: center; padding: 40px 20px; color: #94a3b8;">
1272
- <div style="font-size: 3em; margin-bottom: 12px;">💬</div>
1273
- <h3 style="color: #64748b; font-weight: 500;">Ask a Medical Question</h3>
1274
- <p>Enter your question on the left and click <strong>Ask Question</strong> to get evidence-based answers.</p>
1275
- </div>
1276
- """,
1277
  elem_classes="qa-output"
1278
  )
1279
-
1280
  # Q&A Event Handlers
1281
  qa_submit_btn.click(
1282
  fn=streaming_answer,
1283
- inputs=[qa_question, qa_context],
1284
  outputs=qa_answer,
1285
  show_progress="minimal",
 
 
 
1286
  )
1287
-
1288
  qa_clear_btn.click(
1289
- fn=lambda: ("", "", """
1290
- <div style="text-align: center; padding: 40px 20px; color: #94a3b8;">
1291
- <div style="font-size: 3em; margin-bottom: 12px;">💬</div>
1292
- <h3 style="color: #64748b; font-weight: 500;">Ask a Medical Question</h3>
1293
- <p>Enter your question on the left and click <strong>Ask Question</strong> to get evidence-based answers.</p>
1294
- </div>
1295
- """),
1296
- outputs=[qa_question, qa_context, qa_answer],
1297
  )
1298
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1299
  # ===== HOW IT WORKS =====
1300
  gr.HTML('<div class="section-title" style="margin-top: 32px;">🤖 How It Works</div>')
1301
-
1302
  gr.HTML("""
1303
  <div class="agent-grid">
1304
  <div class="agent-card">
@@ -1327,7 +1374,7 @@ def create_demo() -> gr.Blocks:
1327
  </div>
1328
  </div>
1329
  """)
1330
-
1331
  # ===== DISCLAIMER =====
1332
  gr.HTML("""
1333
  <div class="disclaimer">
@@ -1337,7 +1384,7 @@ def create_demo() -> gr.Blocks:
1337
  clinical guidelines and may not account for your specific medical history.
1338
  </div>
1339
  """)
1340
-
1341
  # ===== FOOTER =====
1342
  gr.HTML("""
1343
  <div style="text-align: center; padding: 24px; color: #94a3b8; font-size: 0.85em; margin-top: 24px;">
@@ -1352,7 +1399,7 @@ def create_demo() -> gr.Blocks:
1352
  </p>
1353
  </div>
1354
  """)
1355
-
1356
  # ===== EVENT HANDLERS =====
1357
  analyze_btn.click(
1358
  fn=analyze_biomarkers,
@@ -1360,7 +1407,7 @@ def create_demo() -> gr.Blocks:
1360
  outputs=[summary_output, details_output, status_output],
1361
  show_progress="full",
1362
  )
1363
-
1364
  clear_btn.click(
1365
  fn=lambda: ("", """
1366
  <div style="text-align: center; padding: 60px 20px; color: #94a3b8;">
@@ -1371,7 +1418,7 @@ def create_demo() -> gr.Blocks:
1371
  """, "", ""),
1372
  outputs=[input_text, summary_output, details_output, status_output],
1373
  )
1374
-
1375
  return demo
1376
 
1377
 
@@ -1381,9 +1428,9 @@ def create_demo() -> gr.Blocks:
1381
 
1382
  if __name__ == "__main__":
1383
  logger.info("Starting MediGuard AI Gradio App...")
1384
-
1385
  demo = create_demo()
1386
-
1387
  # Launch with HF Spaces compatible settings
1388
  demo.launch(
1389
  server_name="0.0.0.0",
 
37
  import time
38
  import traceback
39
  from pathlib import Path
40
+ from typing import Any
41
 
42
  # Ensure project root is in path
43
  _project_root = str(Path(__file__).parent.parent)
 
114
  """
115
  groq_key, google_key = get_api_keys()
116
  provider = None
117
+
118
  if groq_key:
119
  os.environ["LLM_PROVIDER"] = "groq"
120
  os.environ["GROQ_API_KEY"] = groq_key
 
127
  os.environ["GEMINI_MODEL"] = get_gemini_model()
128
  provider = "gemini"
129
  logger.info(f"Configured Gemini provider with model: {get_gemini_model()}")
130
+
131
  # Set up embedding provider
132
  embedding_provider = get_embedding_provider()
133
  os.environ["EMBEDDING_PROVIDER"] = embedding_provider
134
+
135
  # If Jina is configured, set the API key
136
  jina_key = get_jina_api_key()
137
  if jina_key:
138
  os.environ["JINA_API_KEY"] = jina_key
139
  os.environ["EMBEDDING__JINA_API_KEY"] = jina_key
140
  logger.info("Jina embeddings configured")
141
+
142
  # Set up Langfuse if enabled
143
  if is_langfuse_enabled():
144
  os.environ["LANGFUSE__ENABLED"] = "true"
 
147
  if val:
148
  os.environ[var] = val
149
  logger.info("Langfuse observability enabled")
150
+
151
  return provider
152
 
153
 
 
192
  def get_guild():
193
  """Lazy initialization of the Clinical Insight Guild."""
194
  global _guild, _guild_error, _guild_provider
195
+
196
  # Check if we need to reinitialize (provider changed)
197
  current_provider = os.getenv("LLM_PROVIDER")
198
  if _guild_provider and _guild_provider != current_provider:
199
  logger.info(f"Provider changed from {_guild_provider} to {current_provider}, reinitializing...")
200
  reset_guild()
201
+
202
  if _guild is not None:
203
  return _guild
204
+
205
  if _guild_error is not None:
206
  # Don't cache errors forever - allow retry
207
  logger.warning("Previous initialization failed, retrying...")
208
  _guild_error = None
209
+
210
  try:
211
  logger.info("Initializing Clinical Insight Guild...")
212
  logger.info(f" LLM_PROVIDER: {os.getenv('LLM_PROVIDER', 'not set')}")
 
214
  logger.info(f" GOOGLE_API_KEY: {'✓ set' if os.getenv('GOOGLE_API_KEY') else '✗ not set'}")
215
  logger.info(f" EMBEDDING_PROVIDER: {os.getenv('EMBEDDING_PROVIDER', 'huggingface')}")
216
  logger.info(f" JINA_API_KEY: {'✓ set' if os.getenv('JINA_API_KEY') else '✗ not set'}")
217
+
218
  start = time.time()
219
+
220
  from src.workflow import create_guild
221
  _guild = create_guild()
222
  _guild_provider = current_provider
223
+
224
  elapsed = time.time() - start
225
  logger.info(f"Guild initialized in {elapsed:.1f}s")
226
  return _guild
227
+
228
  except Exception as exc:
229
  logger.error(f"Failed to initialize guild: {exc}")
230
  _guild_error = exc
 
237
 
238
  # Import shared parsing and prediction logic
239
  from src.shared_utils import (
 
240
  get_primary_prediction,
241
+ parse_biomarkers,
 
 
242
  )
243
 
244
 
 
264
  <p style="margin: 8px 0 0 0; color: #64748b;">Please enter biomarkers to analyze.</p>
265
  </div>
266
  """
267
+
268
  # Check API key dynamically (HF injects secrets after startup)
269
  groq_key, google_key = get_api_keys()
270
+
271
  if not groq_key and not google_key:
272
  return "", "", """
273
  <div style="background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%); border: 1px solid #ef4444; border-radius: 10px; padding: 16px;">
 
294
  </details>
295
  </div>
296
  """
297
+
298
  # Setup provider based on available key
299
  provider = setup_llm_provider()
300
  logger.info(f"Using LLM provider: {provider}")
301
+
302
  try:
303
  progress(0.1, desc="📝 Parsing biomarkers...")
304
  biomarkers = parse_biomarkers(input_text)
305
+
306
  if not biomarkers:
307
  return "", "", """
308
  <div style="background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); border: 1px solid #fbbf24; border-radius: 10px; padding: 16px;">
 
314
  </ul>
315
  </div>
316
  """
317
+
318
  progress(0.2, desc="🔧 Initializing AI agents...")
319
+
320
  # Initialize guild
321
  guild = get_guild()
322
+
323
  # Prepare input
324
  from src.state import PatientInput
325
+
326
  # Auto-generate prediction based on common patterns
327
  prediction = auto_predict(biomarkers)
328
+
329
  patient_input = PatientInput(
330
  biomarkers=biomarkers,
331
  model_prediction=prediction,
332
  patient_context={"patient_id": "HF_User", "source": "huggingface_spaces"}
333
  )
334
+
335
  progress(0.4, desc="🤖 Running Clinical Insight Guild...")
336
+
337
  # Run analysis
338
  start = time.time()
339
  result = guild.run(patient_input)
340
  elapsed = time.time() - start
341
+
342
  progress(0.9, desc="✨ Formatting results...")
343
+
344
  # Extract response
345
  final_response = result.get("final_response", {})
346
+
347
  # Format summary
348
  summary = format_summary(final_response, elapsed)
349
+
350
  # Format details
351
  details = json.dumps(final_response, indent=2, default=str)
352
+
353
  status = f"""
354
  <div style="background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%); border: 1px solid #10b981; border-radius: 10px; padding: 12px; display: flex; align-items: center; gap: 10px;">
355
  <span style="font-size: 1.5em;">✅</span>
 
359
  </div>
360
  </div>
361
  """
362
+
363
  return summary, details, status
364
+
365
  except Exception as exc:
366
  logger.error(f"Analysis error: {exc}", exc_info=True)
367
  error_msg = f"""
 
381
  """Format the analysis response as clean markdown with black text."""
382
  if not response:
383
  return "❌ **No analysis results available.**"
384
+
385
  parts = []
386
+
387
  # Header with primary finding and confidence
388
  primary = response.get("primary_finding", "Analysis Complete")
389
  confidence = response.get("confidence", {})
390
  conf_score = confidence.get("overall_score", 0) if isinstance(confidence, dict) else 0
391
+
392
  # Determine severity
393
  severity = response.get("severity", "low")
394
  severity_config = {
 
398
  "low": ("🟢", "#16a34a", "#f0fdf4")
399
  }
400
  emoji, color, bg_color = severity_config.get(severity, severity_config["low"])
401
+
402
  # Build confidence display
403
  conf_badge = ""
404
  if conf_score:
405
  conf_pct = int(conf_score * 100)
406
  conf_color = "#16a34a" if conf_pct >= 80 else "#ca8a04" if conf_pct >= 60 else "#dc2626"
407
  conf_badge = f'<span style="background: {conf_color}; color: white; padding: 4px 12px; border-radius: 20px; font-size: 0.85em; margin-left: 12px;">{conf_pct}% confidence</span>'
408
+
409
  parts.append(f"""
410
  <div style="background: linear-gradient(135deg, {bg_color} 0%, white 100%); border-left: 4px solid {color}; border-radius: 12px; padding: 20px; margin-bottom: 20px;">
411
  <div style="display: flex; align-items: center; flex-wrap: wrap;">
 
414
  {conf_badge}
415
  </div>
416
  </div>""")
417
+
418
  # Critical Alerts
419
  alerts = response.get("safety_alerts", [])
420
  if alerts:
 
424
  alert_items += f'<li><strong>{alert.get("alert_type", "Alert")}:</strong> {alert.get("message", "")}</li>'
425
  else:
426
  alert_items += f'<li>{alert}</li>'
427
+
428
  parts.append(f"""
429
  <div style="background: linear-gradient(135deg, #fef2f2 0%, #fee2e2 100%); border: 1px solid #fecaca; border-radius: 12px; padding: 16px; margin-bottom: 16px;">
430
  <h4 style="margin: 0 0 12px 0; color: #dc2626; display: flex; align-items: center; gap: 8px;">
 
433
  <ul style="margin: 0; padding-left: 20px; color: #991b1b;">{alert_items}</ul>
434
  </div>
435
  """)
436
+
437
  # Key Findings
438
  findings = response.get("key_findings", [])
439
  if findings:
 
444
  <ul style="margin: 0; padding-left: 20px; color: #475569;">{finding_items}</ul>
445
  </div>
446
  """)
447
+
448
  # Biomarker Flags - as a visual grid
449
  flags = response.get("biomarker_flags", [])
450
  if flags and len(flags) > 0:
 
457
  continue
458
  status = flag.get("status", "normal").lower()
459
  value = flag.get("value", flag.get("result", "N/A"))
460
+
461
  status_styles = {
462
  "critical": ("🔴", "#dc2626", "#fef2f2"),
463
  "high": ("🔴", "#dc2626", "#fef2f2"),
 
466
  "normal": ("🟢", "#16a34a", "#f0fdf4")
467
  }
468
  s_emoji, s_color, s_bg = status_styles.get(status, status_styles["normal"])
469
+
470
  flag_cards += f"""
471
  <div style="background: {s_bg}; border: 1px solid {s_color}33; border-radius: 8px; padding: 12px; text-align: center;">
472
  <div style="font-size: 1.2em;">{s_emoji}</div>
 
475
  <div style="font-size: 0.75em; color: #64748b; text-transform: capitalize;">{status}</div>
476
  </div>
477
  """
478
+
479
  if flag_cards: # Only show section if we have cards
480
  parts.append(f"""
481
  <div style="margin-bottom: 16px;">
 
485
  </div>
486
  </div>
487
  """)
488
+
489
  # Recommendations - organized sections
490
  recs = response.get("recommendations", {})
491
  rec_sections = ""
492
+
493
  immediate = recs.get("immediate_actions", []) if isinstance(recs, dict) else []
494
  if immediate and len(immediate) > 0:
495
  items = "".join([f'<li style="margin-bottom: 6px;">{str(a).strip()}</li>' for a in immediate[:3]])
 
499
  <ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
500
  </div>
501
  """
502
+
503
  lifestyle = recs.get("lifestyle_modifications", []) if isinstance(recs, dict) else []
504
  if lifestyle and len(lifestyle) > 0:
505
  items = "".join([f'<li style="margin-bottom: 6px;">{str(m).strip()}</li>' for m in lifestyle[:3]])
 
509
  <ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
510
  </div>
511
  """
512
+
513
  followup = recs.get("follow_up", []) if isinstance(recs, dict) else []
514
  if followup and len(followup) > 0:
515
  items = "".join([f'<li style="margin-bottom: 6px;">{str(f).strip()}</li>' for f in followup[:3]])
 
519
  <ul style="margin: 0; padding-left: 20px; color: #475569;">{items}</ul>
520
  </div>
521
  """
522
+
523
  # Add default recommendations if none provided
524
  if not rec_sections:
525
+ rec_sections = """
526
  <div style="margin-bottom: 12px;">
527
  <h5 style="margin: 0 0 8px 0; color: #2563eb;">📋 General Recommendations</h5>
528
  <ul style="margin: 0; padding-left: 20px; color: #475569;">
 
532
  </ul>
533
  </div>
534
  """
535
+
536
  if rec_sections:
537
  parts.append(f"""
538
  <div style="background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%); border-radius: 12px; padding: 16px; margin-bottom: 16px;">
 
540
  {rec_sections}
541
  </div>
542
  """)
543
+
544
  # Disease Explanation
545
  explanation = response.get("disease_explanation", {})
546
  if explanation and isinstance(explanation, dict):
 
552
  <p style="margin: 0; color: #475569; line-height: 1.6;">{pathophys[:600]}{'...' if len(pathophys) > 600 else ''}</p>
553
  </div>
554
  """)
555
+
556
  # Conversational Summary
557
  conv_summary = response.get("conversational_summary", "")
558
  if conv_summary:
 
562
  <p style="margin: 0; color: #475569; line-height: 1.6;">{conv_summary[:1000]}</p>
563
  </div>
564
  """)
565
+
566
  # Footer
567
  parts.append(f"""
568
  <div style="border-top: 1px solid #e2e8f0; padding-top: 16px; margin-top: 8px; text-align: center;">
 
574
  </p>
575
  </div>
576
  """)
577
+
578
  return "\n".join(parts)
579
 
580
 
 
603
  _rag_service_error = None
604
 
605
  try:
606
+ from src.llm_config import get_synthesizer
607
  from src.services.agents.agentic_rag import AgenticRAGService
608
  from src.services.agents.context import AgenticContext
609
  from src.services.retrieval.factory import make_retriever
 
610
 
611
  llm = get_synthesizer()
612
  retriever = make_retriever() # auto-detects FAISS
 
634
 
635
  def _fallback_qa(question: str, context_text: str = "") -> str:
636
  """Direct retriever+LLM fallback when agentic pipeline is unavailable."""
 
637
  from src.llm_config import get_synthesizer
638
+ from src.services.retrieval.factory import make_retriever
639
 
640
  retriever = make_retriever()
641
  search_query = f"{context_text} {question}" if context_text.strip() else question
 
724
 
725
  except Exception as exc:
726
  logger.exception(f"Q&A error: {exc}")
727
+ error_msg = f"❌ Error: {exc!s}"
728
  history = (chat_history or []) + [(question, error_msg)]
729
  return error_msg, history
730
 
731
 
732
+ def streaming_answer(question: str, context: str, history: list, model: str):
733
  """Stream answer using the full agentic RAG pipeline.
734
  Falls back to direct retriever+LLM if the pipeline is unavailable.
735
  """
736
+ history = history or []
737
  if not question.strip():
738
+ yield history
739
  return
740
 
741
+ history.append((question, ""))
742
+
743
  if not groq_key and not google_key:
744
+ history[-1] = (question, "❌ Please add your GROQ_API_KEY or GOOGLE_API_KEY in Space Settings → Secrets.")
745
+ yield history
746
  return
747
 
748
+ # Update provider if model changed (simplified handling for UI demo)
749
+ if "gemini" in model.lower():
750
+ os.environ["LLM_PROVIDER"] = "gemini"
751
+ else:
752
+ os.environ["LLM_PROVIDER"] = "groq"
753
+
754
  setup_llm_provider()
755
 
756
  try:
757
+ history[-1] = (question, "🛡️ Checking medical domain relevance...\n\n")
758
+ yield history
759
 
760
  start_time = time.time()
761
 
762
  rag_service = _get_rag_service()
763
  if rag_service is not None:
764
+ history[-1] = (question, "🛡️ Checking medical domain relevance...\n🔍 Retrieving medical documents...\n\n")
765
+ yield history
766
  result = rag_service.ask(query=question, patient_context=context)
767
  answer = result.get("final_answer", "")
768
  guardrail = result.get("guardrail_score")
769
  docs_relevant = len(result.get("relevant_documents", []))
770
  docs_retrieved = len(result.get("retrieved_documents", []))
771
  else:
772
+ history[-1] = (question, "🔍 Searching medical knowledge base...\n📚 Retrieving relevant documents...\n\n")
773
+ yield history
774
  answer = _fallback_qa(question, context)
775
  guardrail = None
776
  docs_relevant = 0
 
779
  if not answer:
780
  answer = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
781
 
782
+ history[-1] = (question, "🛡️ Guardrail ✓\n🔍 Retrieved ✓\n📊 Graded ✓\n💭 Generating response...\n\n")
783
+ yield history
784
 
785
  elapsed = time.time() - start_time
786
 
 
789
  accumulated = ""
790
  for i, word in enumerate(words):
791
  accumulated += word + " "
792
+ if i % 10 == 0:
793
+ history[-1] = (question, accumulated)
794
+ yield history
795
+ time.sleep(0.01)
796
 
797
  # Final response with metadata
798
  meta_parts = [f"⏱️ {elapsed:.1f}s"]
 
803
  meta_parts.append("🤖 Agentic RAG" if rag_service else "🤖 RAG")
804
  meta_line = " | ".join(meta_parts)
805
 
806
+ final_msg = f"{answer}\n\n---\n*{meta_line}*\n"
807
+ history[-1] = (question, final_msg)
808
+ yield history
 
 
809
 
810
  except Exception as exc:
811
  logger.exception(f"Streaming Q&A error: {exc}")
812
+ history[-1] = (question, f"❌ Error: {exc!s}")
813
+ yield history
814
+
815
+
816
+ def hf_search(query: str, mode: str):
817
+ """Direct fast-retrieval for the HF Space Knowledge tab."""
818
+ if not query.strip():
819
+ return "Please enter a query."
820
+ try:
821
+ from src.services.retrieval.factory import make_retriever
822
+ retriever = make_retriever()
823
+ docs = retriever.retrieve(query, top_k=5)
824
+ if not docs:
825
+ return "No results found."
826
+ parts = []
827
+ for i, doc in enumerate(docs, 1):
828
+ title = doc.metadata.get("title", doc.metadata.get("source_file", "Untitled"))
829
+ score = doc.score if hasattr(doc, 'score') else 0.0
830
+ parts.append(f"**[{i}] {title}** (score: {score:.3f})\n{doc.content}\n")
831
+ return "\n---\n".join(parts)
832
+ except Exception as exc:
833
+ return f"Error: {exc}"
834
 
835
 
836
  # ---------------------------------------------------------------------------
 
1069
 
1070
  def create_demo() -> gr.Blocks:
1071
  """Create the Gradio Blocks interface with modern medical UI."""
1072
+
1073
  with gr.Blocks(
1074
  title="Agentic RagBot - Medical Biomarker Analysis",
1075
  theme=gr.themes.Soft(
 
1095
  ),
1096
  css=CUSTOM_CSS,
1097
  ) as demo:
1098
+
1099
  # ===== HEADER =====
1100
  gr.HTML("""
1101
  <div class="header-container">
 
1109
  </div>
1110
  </div>
1111
  """)
1112
+
1113
  # ===== API KEY INFO =====
1114
  gr.HTML("""
1115
  <div class="info-banner">
 
1126
  </div>
1127
  </div>
1128
  """)
1129
+
1130
  # ===== MAIN TABS =====
1131
  with gr.Tabs() as main_tabs:
1132
+
1133
  # ==================== TAB 1: BIOMARKER ANALYSIS ====================
1134
  with gr.Tab("🔬 Biomarker Analysis", id="biomarker-tab"):
1135
+
1136
  # ===== MAIN CONTENT =====
1137
  with gr.Row(equal_height=False):
1138
+
1139
  # ----- LEFT PANEL: INPUT -----
1140
  with gr.Column(scale=2, min_width=400):
1141
  gr.HTML('<div class="section-title">📝 Enter Your Biomarkers</div>')
1142
+
1143
  with gr.Group():
1144
  input_text = gr.Textbox(
1145
  label="",
 
1148
  max_lines=12,
1149
  show_label=False,
1150
  )
1151
+
1152
  with gr.Row():
1153
  analyze_btn = gr.Button(
1154
+ "🔬 Analyze Biomarkers",
1155
+ variant="primary",
1156
  size="lg",
1157
  scale=3,
1158
  )
1159
  clear_btn = gr.Button(
1160
+ "🗑️ Clear",
1161
  variant="secondary",
1162
  size="lg",
1163
  scale=1,
1164
  )
1165
+
1166
  # Status display
1167
  status_output = gr.Markdown(
1168
  value="",
1169
  elem_classes="status-box"
1170
  )
1171
+
1172
  # Quick Examples
1173
  gr.HTML('<div class="section-title" style="margin-top: 24px;">⚡ Quick Examples</div>')
1174
  gr.HTML('<p style="color: #64748b; font-size: 0.9em; margin-bottom: 12px;">Click any example to load it instantly</p>')
1175
+
1176
  examples = gr.Examples(
1177
  examples=[
1178
  ["Glucose: 185, HbA1c: 8.2, Cholesterol: 245, LDL: 165"],
 
1184
  inputs=input_text,
1185
  label="",
1186
  )
1187
+
1188
  # Supported Biomarkers
1189
  with gr.Accordion("📊 Supported Biomarkers", open=False):
1190
  gr.HTML("""
 
1215
  </div>
1216
  </div>
1217
  """)
1218
+
1219
  # ----- RIGHT PANEL: RESULTS -----
1220
  with gr.Column(scale=3, min_width=500):
1221
  gr.HTML('<div class="section-title">📊 Analysis Results</div>')
1222
+
1223
  with gr.Tabs() as result_tabs:
1224
  with gr.Tab("📋 Summary", id="summary"):
1225
  summary_output = gr.Markdown(
 
1232
  """,
1233
  elem_classes="summary-output"
1234
  )
1235
+
1236
  with gr.Tab("🔍 Detailed JSON", id="json"):
1237
  details_output = gr.Code(
1238
  label="",
 
1240
  lines=30,
1241
  show_label=False,
1242
  )
1243
+
1244
  # ==================== TAB 2: MEDICAL Q&A ====================
1245
  with gr.Tab("💬 Medical Q&A", id="qa-tab"):
1246
+
1247
  gr.HTML("""
1248
  <div style="margin-bottom: 20px;">
1249
  <h3 style="color: #1e3a5f; margin: 0 0 8px 0;">💬 Medical Q&A Assistant</h3>
 
1252
  </p>
1253
  </div>
1254
  """)
1255
+
1256
  with gr.Row(equal_height=False):
1257
  with gr.Column(scale=1):
1258
  qa_context = gr.Textbox(
 
1261
  lines=3,
1262
  max_lines=6,
1263
  )
1264
+ qa_model = gr.Dropdown(
1265
+ choices=["llama-3.3-70b-versatile", "gemini-2.0-flash", "llama3.1:8b"],
1266
+ value="llama-3.3-70b-versatile",
1267
+ label="LLM Provider/Model"
1268
+ )
1269
  qa_question = gr.Textbox(
1270
  label="Your Question",
1271
  placeholder="Ask any medical question...\n• What do my elevated glucose levels indicate?\n• Should I be concerned about my HbA1c of 7.5%?\n• What lifestyle changes help with prediabetes?",
 
1281
  )
1282
  qa_clear_btn = gr.Button(
1283
  "🗑️ Clear",
1284
+ variant="secondary",
1285
  size="lg",
1286
  scale=1,
1287
  )
1288
+
1289
  # Quick question examples
1290
  gr.HTML('<h4 style="margin-top: 16px; color: #1e3a5f;">Example Questions</h4>')
1291
  qa_examples = gr.Examples(
 
1298
  inputs=[qa_question, qa_context],
1299
  label="",
1300
  )
1301
+
1302
  with gr.Column(scale=2):
1303
  gr.HTML('<h4 style="color: #1e3a5f; margin-bottom: 12px;">📝 Answer</h4>')
1304
+ qa_answer = gr.Chatbot(
1305
+ label="Medical Q&A History",
1306
+ height=600,
 
 
 
 
 
1307
  elem_classes="qa-output"
1308
  )
1309
+
1310
  # Q&A Event Handlers
1311
  qa_submit_btn.click(
1312
  fn=streaming_answer,
1313
+ inputs=[qa_question, qa_context, qa_answer, qa_model],
1314
  outputs=qa_answer,
1315
  show_progress="minimal",
1316
+ ).then(
1317
+ fn=lambda: "",
1318
+ outputs=qa_question
1319
  )
1320
+
1321
  qa_clear_btn.click(
1322
+ fn=lambda: ([], ""),
1323
+ outputs=[qa_answer, qa_question],
 
 
 
 
 
 
1324
  )
1325
+
1326
+ # ==================== TAB 3: SEARCH KNOWLEDGE BASE ====================
1327
+ with gr.Tab("🔍 Search Knowledge Base", id="search-tab"):
1328
+ with gr.Row():
1329
+ search_input = gr.Textbox(
1330
+ label="Search Query",
1331
+ placeholder="e.g., diabetes management guidelines",
1332
+ lines=2,
1333
+ scale=3
1334
+ )
1335
+ search_mode = gr.Radio(
1336
+ choices=["hybrid", "bm25", "vector"],
1337
+ value="hybrid",
1338
+ label="Search Strategy",
1339
+ scale=1
1340
+ )
1341
+ search_btn = gr.Button("Search", variant="primary")
1342
+ search_output = gr.Textbox(label="Results", lines=20, interactive=False)
1343
+
1344
+ search_btn.click(fn=hf_search, inputs=[search_input, search_mode], outputs=search_output)
1345
+
1346
  # ===== HOW IT WORKS =====
1347
  gr.HTML('<div class="section-title" style="margin-top: 32px;">🤖 How It Works</div>')
1348
+
1349
  gr.HTML("""
1350
  <div class="agent-grid">
1351
  <div class="agent-card">
 
1374
  </div>
1375
  </div>
1376
  """)
1377
+
1378
  # ===== DISCLAIMER =====
1379
  gr.HTML("""
1380
  <div class="disclaimer">
 
1384
  clinical guidelines and may not account for your specific medical history.
1385
  </div>
1386
  """)
1387
+
1388
  # ===== FOOTER =====
1389
  gr.HTML("""
1390
  <div style="text-align: center; padding: 24px; color: #94a3b8; font-size: 0.85em; margin-top: 24px;">
 
1399
  </p>
1400
  </div>
1401
  """)
1402
+
1403
  # ===== EVENT HANDLERS =====
1404
  analyze_btn.click(
1405
  fn=analyze_biomarkers,
 
1407
  outputs=[summary_output, details_output, status_output],
1408
  show_progress="full",
1409
  )
1410
+
1411
  clear_btn.click(
1412
  fn=lambda: ("", """
1413
  <div style="text-align: center; padding: 60px 20px; color: #94a3b8;">
 
1418
  """, "", ""),
1419
  outputs=[input_text, summary_output, details_output, status_output],
1420
  )
1421
+
1422
  return demo
1423
 
1424
 
 
1428
 
1429
  if __name__ == "__main__":
1430
  logger.info("Starting MediGuard AI Gradio App...")
1431
+
1432
  demo = create_demo()
1433
+
1434
  # Launch with HF Spaces compatible settings
1435
  demo.launch(
1436
  server_name="0.0.0.0",
pytest.ini CHANGED
@@ -2,3 +2,6 @@
2
  filterwarnings =
3
  ignore::langchain_core._api.deprecation.LangChainDeprecationWarning
4
  ignore:.*class.*HuggingFaceEmbeddings.*was deprecated.*:DeprecationWarning
 
 
 
 
2
  filterwarnings =
3
  ignore::langchain_core._api.deprecation.LangChainDeprecationWarning
4
  ignore:.*class.*HuggingFaceEmbeddings.*was deprecated.*:DeprecationWarning
5
+
6
+ markers =
7
+ integration: mark a test as an integration test.
requirements.txt DELETED
@@ -1,41 +0,0 @@
1
- # MediGuard AI RAG-Helper - Dependencies
2
-
3
- # Core Framework
4
- langchain>=0.1.0
5
- langgraph>=0.0.20
6
- langchain-community>=0.0.13
7
- langchain-core>=0.1.10
8
-
9
- # LLM Providers (Cloud - FREE tiers available)
10
- langchain-groq>=0.1.0 # Groq API (FREE tier, llama-3.3-70b)
11
- langchain-google-genai>=1.0.0 # Google Gemini (FREE tier)
12
-
13
- # Local LLM (optional, for offline use)
14
- # ollama>=0.1.6
15
-
16
- # Vector Store & Embeddings
17
- faiss-cpu>=1.9.0
18
- sentence-transformers>=2.2.2
19
-
20
- # Document Processing
21
- pypdf>=3.17.4
22
- pydantic>=2.5.3
23
-
24
- # Data Handling
25
- pandas>=2.1.4
26
-
27
- # Environment & Configuration
28
- python-dotenv>=1.0.0
29
-
30
- # Utilities
31
- numpy>=1.26.2
32
- matplotlib>=3.8.2
33
-
34
- # Optional: improved readability scoring for evaluations
35
- textstat>=0.7.3
36
-
37
- # Optional: HuggingFace embedding provider
38
- # langchain-huggingface>=0.0.1
39
-
40
- # Optional: Ollama local LLM provider
41
- # langchain-ollama>=0.0.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/chat.py CHANGED
@@ -4,9 +4,9 @@ Enables natural language conversation with the RAG system
4
  """
5
 
6
  import json
7
- import sys
8
- import os
9
  import logging
 
 
10
  import warnings
11
 
12
  # ── Silence HuggingFace / transformers noise BEFORE any ML library is loaded ──
@@ -21,9 +21,9 @@ logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
21
  warnings.filterwarnings("ignore", message=".*class.*HuggingFaceEmbeddings.*was deprecated.*")
22
  # ─────────────────────────────────────────────────────────────────────────────
23
 
24
- from pathlib import Path
25
- from typing import Dict, Any, Tuple
26
  from datetime import datetime
 
 
27
 
28
  # Set UTF-8 encoding for Windows console
29
  if sys.platform == 'win32':
@@ -40,11 +40,11 @@ if sys.platform == 'win32':
40
  sys.path.insert(0, str(Path(__file__).parent.parent))
41
 
42
  from langchain_core.prompts import ChatPromptTemplate
 
43
  from src.biomarker_normalization import normalize_biomarker_name
44
  from src.llm_config import get_chat_model
45
- from src.workflow import create_guild
46
  from src.state import PatientInput
47
-
48
 
49
  # ============================================================================
50
  # BIOMARKER EXTRACTION PROMPT
@@ -82,7 +82,7 @@ If you cannot find any biomarkers, return {{"biomarkers": {{}}, "patient_context
82
  # Component 1: Biomarker Extraction
83
  # ============================================================================
84
 
85
- def _parse_llm_json(content: str) -> Dict[str, Any]:
86
  """Parse JSON payload from LLM output with fallback recovery."""
87
  text = content.strip()
88
 
@@ -101,7 +101,7 @@ def _parse_llm_json(content: str) -> Dict[str, Any]:
101
  raise
102
 
103
 
104
- def extract_biomarkers(user_message: str) -> Tuple[Dict[str, float], Dict[str, Any]]:
105
  """
106
  Extract biomarker values from natural language using LLM.
107
 
@@ -111,17 +111,17 @@ def extract_biomarkers(user_message: str) -> Tuple[Dict[str, float], Dict[str, A
111
  try:
112
  llm = get_chat_model(temperature=0.0)
113
  prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT)
114
-
115
  chain = prompt | llm
116
  response = chain.invoke({"user_message": user_message})
117
-
118
  # Parse JSON from LLM response
119
  content = response.content.strip()
120
-
121
  extracted = _parse_llm_json(content)
122
  biomarkers = extracted.get("biomarkers", {})
123
  patient_context = extracted.get("patient_context", {})
124
-
125
  # Normalize biomarker names
126
  normalized = {}
127
  for key, value in biomarkers.items():
@@ -131,12 +131,12 @@ def extract_biomarkers(user_message: str) -> Tuple[Dict[str, float], Dict[str, A
131
  except (ValueError, TypeError) as e:
132
  print(f"⚠️ Skipping invalid value for {key}: {value} (error: {e})")
133
  continue
134
-
135
  # Clean up patient context (remove null values)
136
  patient_context = {k: v for k, v in patient_context.items() if v is not None}
137
-
138
  return normalized, patient_context
139
-
140
  except Exception as e:
141
  print(f"⚠️ Extraction failed: {e}")
142
  import traceback
@@ -148,7 +148,7 @@ def extract_biomarkers(user_message: str) -> Tuple[Dict[str, float], Dict[str, A
148
  # Component 2: Disease Prediction
149
  # ============================================================================
150
 
151
- def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
152
  """
153
  Simple rule-based disease prediction based on key biomarkers.
154
  """
@@ -159,15 +159,15 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
159
  "Thrombocytopenia": 0.0,
160
  "Thalassemia": 0.0
161
  }
162
-
163
  # Helper: check both abbreviated and normalized biomarker names
164
  # Returns None when biomarker is not present (avoids false triggers)
165
  def _get(name, *alt_names):
166
- val = biomarkers.get(name, None)
167
  if val is not None:
168
  return val
169
  for alt in alt_names:
170
- val = biomarkers.get(alt, None)
171
  if val is not None:
172
  return val
173
  return None
@@ -181,7 +181,7 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
181
  scores["Diabetes"] += 0.2
182
  if hba1c is not None and hba1c >= 6.5:
183
  scores["Diabetes"] += 0.5
184
-
185
  # Anemia indicators
186
  hemoglobin = _get("Hemoglobin")
187
  mcv = _get("Mean Corpuscular Volume", "MCV")
@@ -191,7 +191,7 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
191
  scores["Anemia"] += 0.2
192
  if mcv is not None and mcv < 80:
193
  scores["Anemia"] += 0.2
194
-
195
  # Heart disease indicators
196
  cholesterol = _get("Cholesterol")
197
  troponin = _get("Troponin")
@@ -202,32 +202,32 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
202
  scores["Heart Disease"] += 0.6
203
  if ldl is not None and ldl > 190:
204
  scores["Heart Disease"] += 0.2
205
-
206
  # Thrombocytopenia indicators
207
  platelets = _get("Platelets")
208
  if platelets is not None and platelets < 150000:
209
  scores["Thrombocytopenia"] += 0.6
210
  if platelets is not None and platelets < 50000:
211
  scores["Thrombocytopenia"] += 0.3
212
-
213
  # Thalassemia indicators (complex, simplified here)
214
  if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
215
  scores["Thalassemia"] += 0.4
216
-
217
  # Find top prediction
218
  top_disease = max(scores, key=scores.get)
219
  confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation
220
-
221
  if confidence == 0.0:
222
  top_disease = "Undetermined"
223
-
224
  # Normalize probabilities to sum to 1.0
225
  total = sum(scores.values())
226
  if total > 0:
227
  probabilities = {k: v / total for k, v in scores.items()}
228
  else:
229
  probabilities = {k: 1.0 / len(scores) for k in scores}
230
-
231
  return {
232
  "disease": top_disease,
233
  "confidence": confidence,
@@ -235,14 +235,14 @@ def predict_disease_simple(biomarkers: Dict[str, float]) -> Dict[str, Any]:
235
  }
236
 
237
 
238
- def predict_disease_llm(biomarkers: Dict[str, float], patient_context: Dict) -> Dict[str, Any]:
239
  """
240
  Use LLM to predict most likely disease based on biomarker pattern.
241
  Falls back to rule-based if LLM fails.
242
  """
243
  try:
244
  llm = get_chat_model(temperature=0.0)
245
-
246
  prompt = f"""You are a medical AI assistant. Based on these biomarker values,
247
  predict the most likely disease from: Diabetes, Anemia, Heart Disease, Thrombocytopenia, Thalassemia.
248
 
@@ -265,18 +265,18 @@ Return ONLY valid JSON (no other text):
265
  }}
266
  }}
267
  """
268
-
269
  response = llm.invoke(prompt)
270
  content = response.content.strip()
271
-
272
  prediction = _parse_llm_json(content)
273
-
274
  # Validate required fields
275
  if "disease" in prediction and "confidence" in prediction and "probabilities" in prediction:
276
  return prediction
277
  else:
278
  raise ValueError("Invalid prediction format")
279
-
280
  except Exception as e:
281
  print(f"⚠️ LLM prediction failed ({e}), using rule-based fallback")
282
  import traceback
@@ -288,7 +288,7 @@ Return ONLY valid JSON (no other text):
288
  # Component 3: Conversational Formatter
289
  # ============================================================================
290
 
291
- def _coerce_to_dict(obj) -> Dict:
292
  """Convert a Pydantic model or arbitrary object to a plain dict."""
293
  if isinstance(obj, dict):
294
  return obj
@@ -299,7 +299,7 @@ def _coerce_to_dict(obj) -> Dict:
299
  return {}
300
 
301
 
302
- def format_conversational(result: Dict[str, Any], user_name: str = "there") -> str:
303
  """
304
  Format technical JSON output into conversational response.
305
  """
@@ -313,22 +313,22 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
313
  confidence = result.get("confidence_assessment", {}) or {}
314
  # Normalize: items may be Pydantic SafetyAlert objects or plain dicts
315
  alerts = [_coerce_to_dict(a) for a in (result.get("safety_alerts") or [])]
316
-
317
  disease = prediction.get("primary_disease", "Unknown")
318
  conf_score = prediction.get("confidence", 0.0)
319
-
320
  # Build conversational response
321
  response = []
322
-
323
  # 1. Greeting and main finding
324
  response.append(f"Hi {user_name}! 👋\n")
325
- response.append(f"Based on your biomarkers, I analyzed your results.\n")
326
-
327
  # 2. Primary diagnosis with confidence
328
  emoji = "🔴" if conf_score >= 0.8 else "🟡" if conf_score >= 0.6 else "🟢"
329
  response.append(f"{emoji} **Primary Finding:** {disease}")
330
  response.append(f" Confidence: {conf_score:.0%}\n")
331
-
332
  # 3. Critical safety alerts (if any)
333
  critical_alerts = [a for a in alerts if a.get("severity") == "CRITICAL"]
334
  if critical_alerts:
@@ -337,7 +337,7 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
337
  response.append(f" • {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}")
338
  response.append(f" → {alert.get('action', 'Consult healthcare provider')}")
339
  response.append("")
340
-
341
  # 4. Key drivers explanation
342
  key_drivers = prediction.get("key_drivers", [])
343
  if key_drivers:
@@ -351,7 +351,7 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
351
  explanation = explanation[:147] + "..."
352
  response.append(f" • **{biomarker}** ({value}): {explanation}")
353
  response.append("")
354
-
355
  # 5. What to do next (immediate actions)
356
  immediate = recommendations.get("immediate_actions", [])
357
  if immediate:
@@ -359,7 +359,7 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
359
  for i, action in enumerate(immediate[:3], 1):
360
  response.append(f" {i}. {action}")
361
  response.append("")
362
-
363
  # 6. Lifestyle recommendations
364
  lifestyle = recommendations.get("lifestyle_changes", [])
365
  if lifestyle:
@@ -367,11 +367,11 @@ def format_conversational(result: Dict[str, Any], user_name: str = "there") -> s
367
  for i, change in enumerate(lifestyle[:3], 1):
368
  response.append(f" {i}. {change}")
369
  response.append("")
370
-
371
  # 7. Disclaimer
372
  response.append("ℹ️ **Important:** This is an AI-assisted analysis, NOT medical advice.")
373
  response.append(" Please consult a healthcare professional for proper diagnosis and treatment.\n")
374
-
375
  return "\n".join(response)
376
 
377
 
@@ -397,7 +397,7 @@ def run_example_case(guild):
397
  """Run example diabetes patient case"""
398
  print("\n📋 Running Example: Type 2 Diabetes Patient")
399
  print(" 52-year-old male with elevated glucose and HbA1c\n")
400
-
401
  example_biomarkers = {
402
  "Glucose": 185.0,
403
  "HbA1c": 8.2,
@@ -411,7 +411,7 @@ def run_example_case(guild):
411
  "Systolic Blood Pressure": 145,
412
  "Diastolic Blood Pressure": 92
413
  }
414
-
415
  prediction = {
416
  "disease": "Diabetes",
417
  "confidence": 0.87,
@@ -423,16 +423,16 @@ def run_example_case(guild):
423
  "Thalassemia": 0.01
424
  }
425
  }
426
-
427
  patient_input = PatientInput(
428
  biomarkers=example_biomarkers,
429
  model_prediction=prediction,
430
  patient_context={"age": 52, "gender": "male", "bmi": 31.2}
431
  )
432
-
433
  print("🔄 Running analysis...\n")
434
  result = guild.run(patient_input)
435
-
436
  response = format_conversational(result.get("final_response", result), "there")
437
  print("\n" + "="*70)
438
  print("🤖 RAG-BOT:")
@@ -441,7 +441,7 @@ def run_example_case(guild):
441
  print("="*70 + "\n")
442
 
443
 
444
- def save_report(result: Dict, biomarkers: Dict):
445
  """Save detailed JSON report to file"""
446
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
447
 
@@ -505,7 +505,7 @@ def chat_interface():
505
  print(" 3. Type 'help' for biomarker list")
506
  print(" 4. Type 'quit' to exit\n")
507
  print("="*70 + "\n")
508
-
509
  # Initialize guild (one-time setup)
510
  print("🔧 Initializing medical knowledge system...")
511
  try:
@@ -518,78 +518,78 @@ def chat_interface():
518
  print(" • Vector store exists (run: python scripts/setup_embeddings.py)")
519
  print(" • Internet connection is available for cloud LLM")
520
  return
521
-
522
  # Main conversation loop
523
  conversation_history = []
524
  user_name = "there"
525
-
526
  while True:
527
  try:
528
  # Get user input
529
  user_input = input("You: ").strip()
530
-
531
  if not user_input:
532
  continue
533
-
534
  # Handle special commands
535
  if user_input.lower() in ['quit', 'exit', 'q']:
536
  print("\n👋 Thank you for using MediGuard AI. Stay healthy!")
537
  break
538
-
539
  if user_input.lower() == 'help':
540
  print_biomarker_help()
541
  continue
542
-
543
  if user_input.lower() == 'example':
544
  run_example_case(guild)
545
  continue
546
-
547
  # Extract biomarkers from natural language
548
  print("\n🔍 Analyzing your input...")
549
  biomarkers, patient_context = extract_biomarkers(user_input)
550
-
551
  if not biomarkers:
552
  print("❌ I couldn't find any biomarker values in your message.")
553
  print(" Try: 'My glucose is 140 and HbA1c is 7.5'")
554
  print(" Or type 'help' to see all biomarkers I can analyze.\n")
555
  continue
556
-
557
  print(f"✅ Found {len(biomarkers)} biomarker(s): {', '.join(biomarkers.keys())}")
558
-
559
  # Check if we have enough biomarkers (minimum 2)
560
  if len(biomarkers) < 2:
561
  print("⚠️ I need at least 2 biomarkers for a reliable analysis.")
562
  print(" Can you provide more values?\n")
563
  continue
564
-
565
  # Generate disease prediction
566
  print("🧠 Predicting likely condition...")
567
  prediction = predict_disease_llm(biomarkers, patient_context)
568
  print(f"✅ Predicted: {prediction['disease']} ({prediction['confidence']:.0%} confidence)")
569
-
570
  # Create PatientInput
571
  patient_input = PatientInput(
572
  biomarkers=biomarkers,
573
  model_prediction=prediction,
574
  patient_context=patient_context if patient_context else {"source": "chat"}
575
  )
576
-
577
  # Run full RAG workflow
578
  print("📚 Consulting medical knowledge base...")
579
  print(" (This may take 15-25 seconds...)\n")
580
-
581
  result = guild.run(patient_input)
582
-
583
  # Format conversational response
584
  response = format_conversational(result.get("final_response", result), user_name)
585
-
586
  # Display response
587
  print("\n" + "="*70)
588
  print("🤖 RAG-BOT:")
589
  print("="*70)
590
  print(response)
591
  print("="*70 + "\n")
592
-
593
  # Save to history
594
  conversation_history.append({
595
  "user_input": user_input,
@@ -597,16 +597,16 @@ def chat_interface():
597
  "prediction": prediction,
598
  "result": result
599
  })
600
-
601
  # Ask if user wants to save report
602
  save_choice = input("💾 Save detailed report to file? (y/n): ").strip().lower()
603
  if save_choice == 'y':
604
  save_report(result, biomarkers)
605
-
606
  print("\nYou can:")
607
  print(" • Enter more biomarkers for a new analysis")
608
  print(" • Type 'quit' to exit\n")
609
-
610
  except KeyboardInterrupt:
611
  print("\n\n👋 Interrupted. Thank you for using MediGuard AI!")
612
  break
 
4
  """
5
 
6
  import json
 
 
7
  import logging
8
+ import os
9
+ import sys
10
  import warnings
11
 
12
  # ── Silence HuggingFace / transformers noise BEFORE any ML library is loaded ──
 
21
  warnings.filterwarnings("ignore", message=".*class.*HuggingFaceEmbeddings.*was deprecated.*")
22
  # ─────────────────────────────────────────────────────────────────────────────
23
 
 
 
24
  from datetime import datetime
25
+ from pathlib import Path
26
+ from typing import Any
27
 
28
  # Set UTF-8 encoding for Windows console
29
  if sys.platform == 'win32':
 
40
  sys.path.insert(0, str(Path(__file__).parent.parent))
41
 
42
  from langchain_core.prompts import ChatPromptTemplate
43
+
44
  from src.biomarker_normalization import normalize_biomarker_name
45
  from src.llm_config import get_chat_model
 
46
  from src.state import PatientInput
47
+ from src.workflow import create_guild
48
 
49
  # ============================================================================
50
  # BIOMARKER EXTRACTION PROMPT
 
82
  # Component 1: Biomarker Extraction
83
  # ============================================================================
84
 
85
+ def _parse_llm_json(content: str) -> dict[str, Any]:
86
  """Parse JSON payload from LLM output with fallback recovery."""
87
  text = content.strip()
88
 
 
101
  raise
102
 
103
 
104
+ def extract_biomarkers(user_message: str) -> tuple[dict[str, float], dict[str, Any]]:
105
  """
106
  Extract biomarker values from natural language using LLM.
107
 
 
111
  try:
112
  llm = get_chat_model(temperature=0.0)
113
  prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT)
114
+
115
  chain = prompt | llm
116
  response = chain.invoke({"user_message": user_message})
117
+
118
  # Parse JSON from LLM response
119
  content = response.content.strip()
120
+
121
  extracted = _parse_llm_json(content)
122
  biomarkers = extracted.get("biomarkers", {})
123
  patient_context = extracted.get("patient_context", {})
124
+
125
  # Normalize biomarker names
126
  normalized = {}
127
  for key, value in biomarkers.items():
 
131
  except (ValueError, TypeError) as e:
132
  print(f"⚠️ Skipping invalid value for {key}: {value} (error: {e})")
133
  continue
134
+
135
  # Clean up patient context (remove null values)
136
  patient_context = {k: v for k, v in patient_context.items() if v is not None}
137
+
138
  return normalized, patient_context
139
+
140
  except Exception as e:
141
  print(f"⚠️ Extraction failed: {e}")
142
  import traceback
 
148
  # Component 2: Disease Prediction
149
  # ============================================================================
150
 
151
+ def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]:
152
  """
153
  Simple rule-based disease prediction based on key biomarkers.
154
  """
 
159
  "Thrombocytopenia": 0.0,
160
  "Thalassemia": 0.0
161
  }
162
+
163
  # Helper: check both abbreviated and normalized biomarker names
164
  # Returns None when biomarker is not present (avoids false triggers)
165
  def _get(name, *alt_names):
166
+ val = biomarkers.get(name)
167
  if val is not None:
168
  return val
169
  for alt in alt_names:
170
+ val = biomarkers.get(alt)
171
  if val is not None:
172
  return val
173
  return None
 
181
  scores["Diabetes"] += 0.2
182
  if hba1c is not None and hba1c >= 6.5:
183
  scores["Diabetes"] += 0.5
184
+
185
  # Anemia indicators
186
  hemoglobin = _get("Hemoglobin")
187
  mcv = _get("Mean Corpuscular Volume", "MCV")
 
191
  scores["Anemia"] += 0.2
192
  if mcv is not None and mcv < 80:
193
  scores["Anemia"] += 0.2
194
+
195
  # Heart disease indicators
196
  cholesterol = _get("Cholesterol")
197
  troponin = _get("Troponin")
 
202
  scores["Heart Disease"] += 0.6
203
  if ldl is not None and ldl > 190:
204
  scores["Heart Disease"] += 0.2
205
+
206
  # Thrombocytopenia indicators
207
  platelets = _get("Platelets")
208
  if platelets is not None and platelets < 150000:
209
  scores["Thrombocytopenia"] += 0.6
210
  if platelets is not None and platelets < 50000:
211
  scores["Thrombocytopenia"] += 0.3
212
+
213
  # Thalassemia indicators (complex, simplified here)
214
  if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
215
  scores["Thalassemia"] += 0.4
216
+
217
  # Find top prediction
218
  top_disease = max(scores, key=scores.get)
219
  confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation
220
+
221
  if confidence == 0.0:
222
  top_disease = "Undetermined"
223
+
224
  # Normalize probabilities to sum to 1.0
225
  total = sum(scores.values())
226
  if total > 0:
227
  probabilities = {k: v / total for k, v in scores.items()}
228
  else:
229
  probabilities = {k: 1.0 / len(scores) for k in scores}
230
+
231
  return {
232
  "disease": top_disease,
233
  "confidence": confidence,
 
235
  }
236
 
237
 
238
+ def predict_disease_llm(biomarkers: dict[str, float], patient_context: dict) -> dict[str, Any]:
239
  """
240
  Use LLM to predict most likely disease based on biomarker pattern.
241
  Falls back to rule-based if LLM fails.
242
  """
243
  try:
244
  llm = get_chat_model(temperature=0.0)
245
+
246
  prompt = f"""You are a medical AI assistant. Based on these biomarker values,
247
  predict the most likely disease from: Diabetes, Anemia, Heart Disease, Thrombocytopenia, Thalassemia.
248
 
 
265
  }}
266
  }}
267
  """
268
+
269
  response = llm.invoke(prompt)
270
  content = response.content.strip()
271
+
272
  prediction = _parse_llm_json(content)
273
+
274
  # Validate required fields
275
  if "disease" in prediction and "confidence" in prediction and "probabilities" in prediction:
276
  return prediction
277
  else:
278
  raise ValueError("Invalid prediction format")
279
+
280
  except Exception as e:
281
  print(f"⚠️ LLM prediction failed ({e}), using rule-based fallback")
282
  import traceback
 
288
  # Component 3: Conversational Formatter
289
  # ============================================================================
290
 
291
+ def _coerce_to_dict(obj) -> dict:
292
  """Convert a Pydantic model or arbitrary object to a plain dict."""
293
  if isinstance(obj, dict):
294
  return obj
 
299
  return {}
300
 
301
 
302
+ def format_conversational(result: dict[str, Any], user_name: str = "there") -> str:
303
  """
304
  Format technical JSON output into conversational response.
305
  """
 
313
  confidence = result.get("confidence_assessment", {}) or {}
314
  # Normalize: items may be Pydantic SafetyAlert objects or plain dicts
315
  alerts = [_coerce_to_dict(a) for a in (result.get("safety_alerts") or [])]
316
+
317
  disease = prediction.get("primary_disease", "Unknown")
318
  conf_score = prediction.get("confidence", 0.0)
319
+
320
  # Build conversational response
321
  response = []
322
+
323
  # 1. Greeting and main finding
324
  response.append(f"Hi {user_name}! 👋\n")
325
+ response.append("Based on your biomarkers, I analyzed your results.\n")
326
+
327
  # 2. Primary diagnosis with confidence
328
  emoji = "🔴" if conf_score >= 0.8 else "🟡" if conf_score >= 0.6 else "🟢"
329
  response.append(f"{emoji} **Primary Finding:** {disease}")
330
  response.append(f" Confidence: {conf_score:.0%}\n")
331
+
332
  # 3. Critical safety alerts (if any)
333
  critical_alerts = [a for a in alerts if a.get("severity") == "CRITICAL"]
334
  if critical_alerts:
 
337
  response.append(f" • {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}")
338
  response.append(f" → {alert.get('action', 'Consult healthcare provider')}")
339
  response.append("")
340
+
341
  # 4. Key drivers explanation
342
  key_drivers = prediction.get("key_drivers", [])
343
  if key_drivers:
 
351
  explanation = explanation[:147] + "..."
352
  response.append(f" • **{biomarker}** ({value}): {explanation}")
353
  response.append("")
354
+
355
  # 5. What to do next (immediate actions)
356
  immediate = recommendations.get("immediate_actions", [])
357
  if immediate:
 
359
  for i, action in enumerate(immediate[:3], 1):
360
  response.append(f" {i}. {action}")
361
  response.append("")
362
+
363
  # 6. Lifestyle recommendations
364
  lifestyle = recommendations.get("lifestyle_changes", [])
365
  if lifestyle:
 
367
  for i, change in enumerate(lifestyle[:3], 1):
368
  response.append(f" {i}. {change}")
369
  response.append("")
370
+
371
  # 7. Disclaimer
372
  response.append("ℹ️ **Important:** This is an AI-assisted analysis, NOT medical advice.")
373
  response.append(" Please consult a healthcare professional for proper diagnosis and treatment.\n")
374
+
375
  return "\n".join(response)
376
 
377
 
 
397
  """Run example diabetes patient case"""
398
  print("\n📋 Running Example: Type 2 Diabetes Patient")
399
  print(" 52-year-old male with elevated glucose and HbA1c\n")
400
+
401
  example_biomarkers = {
402
  "Glucose": 185.0,
403
  "HbA1c": 8.2,
 
411
  "Systolic Blood Pressure": 145,
412
  "Diastolic Blood Pressure": 92
413
  }
414
+
415
  prediction = {
416
  "disease": "Diabetes",
417
  "confidence": 0.87,
 
423
  "Thalassemia": 0.01
424
  }
425
  }
426
+
427
  patient_input = PatientInput(
428
  biomarkers=example_biomarkers,
429
  model_prediction=prediction,
430
  patient_context={"age": 52, "gender": "male", "bmi": 31.2}
431
  )
432
+
433
  print("🔄 Running analysis...\n")
434
  result = guild.run(patient_input)
435
+
436
  response = format_conversational(result.get("final_response", result), "there")
437
  print("\n" + "="*70)
438
  print("🤖 RAG-BOT:")
 
441
  print("="*70 + "\n")
442
 
443
 
444
+ def save_report(result: dict, biomarkers: dict):
445
  """Save detailed JSON report to file"""
446
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
447
 
 
505
  print(" 3. Type 'help' for biomarker list")
506
  print(" 4. Type 'quit' to exit\n")
507
  print("="*70 + "\n")
508
+
509
  # Initialize guild (one-time setup)
510
  print("🔧 Initializing medical knowledge system...")
511
  try:
 
518
  print(" • Vector store exists (run: python scripts/setup_embeddings.py)")
519
  print(" • Internet connection is available for cloud LLM")
520
  return
521
+
522
  # Main conversation loop
523
  conversation_history = []
524
  user_name = "there"
525
+
526
  while True:
527
  try:
528
  # Get user input
529
  user_input = input("You: ").strip()
530
+
531
  if not user_input:
532
  continue
533
+
534
  # Handle special commands
535
  if user_input.lower() in ['quit', 'exit', 'q']:
536
  print("\n👋 Thank you for using MediGuard AI. Stay healthy!")
537
  break
538
+
539
  if user_input.lower() == 'help':
540
  print_biomarker_help()
541
  continue
542
+
543
  if user_input.lower() == 'example':
544
  run_example_case(guild)
545
  continue
546
+
547
  # Extract biomarkers from natural language
548
  print("\n🔍 Analyzing your input...")
549
  biomarkers, patient_context = extract_biomarkers(user_input)
550
+
551
  if not biomarkers:
552
  print("❌ I couldn't find any biomarker values in your message.")
553
  print(" Try: 'My glucose is 140 and HbA1c is 7.5'")
554
  print(" Or type 'help' to see all biomarkers I can analyze.\n")
555
  continue
556
+
557
  print(f"✅ Found {len(biomarkers)} biomarker(s): {', '.join(biomarkers.keys())}")
558
+
559
  # Check if we have enough biomarkers (minimum 2)
560
  if len(biomarkers) < 2:
561
  print("⚠️ I need at least 2 biomarkers for a reliable analysis.")
562
  print(" Can you provide more values?\n")
563
  continue
564
+
565
  # Generate disease prediction
566
  print("🧠 Predicting likely condition...")
567
  prediction = predict_disease_llm(biomarkers, patient_context)
568
  print(f"✅ Predicted: {prediction['disease']} ({prediction['confidence']:.0%} confidence)")
569
+
570
  # Create PatientInput
571
  patient_input = PatientInput(
572
  biomarkers=biomarkers,
573
  model_prediction=prediction,
574
  patient_context=patient_context if patient_context else {"source": "chat"}
575
  )
576
+
577
  # Run full RAG workflow
578
  print("📚 Consulting medical knowledge base...")
579
  print(" (This may take 15-25 seconds...)\n")
580
+
581
  result = guild.run(patient_input)
582
+
583
  # Format conversational response
584
  response = format_conversational(result.get("final_response", result), user_name)
585
+
586
  # Display response
587
  print("\n" + "="*70)
588
  print("🤖 RAG-BOT:")
589
  print("="*70)
590
  print(response)
591
  print("="*70 + "\n")
592
+
593
  # Save to history
594
  conversation_history.append({
595
  "user_input": user_input,
 
597
  "prediction": prediction,
598
  "result": result
599
  })
600
+
601
  # Ask if user wants to save report
602
  save_choice = input("💾 Save detailed report to file? (y/n): ").strip().lower()
603
  if save_choice == 'y':
604
  save_report(result, biomarkers)
605
+
606
  print("\nYou can:")
607
  print(" • Enter more biomarkers for a new analysis")
608
  print(" • Type 'quit' to exit\n")
609
+
610
  except KeyboardInterrupt:
611
  print("\n\n👋 Interrupted. Thank you for using MediGuard AI!")
612
  break
scripts/monitor_test.py CHANGED
@@ -7,6 +7,6 @@ print("=" * 70)
7
  for i in range(60): # Check for 5 minutes
8
  time.sleep(5)
9
  print(f"[{i*5}s] Test still running...")
10
-
11
  print("\nTest should be complete or nearly complete.")
12
  print("Check terminal output for results.")
 
7
  for i in range(60): # Check for 5 minutes
8
  time.sleep(5)
9
  print(f"[{i*5}s] Test still running...")
10
+
11
  print("\nTest should be complete or nearly complete.")
12
  print("Check terminal output for results.")
scripts/setup_embeddings.py CHANGED
@@ -2,22 +2,22 @@
2
  Quick script to help set up Google API key for fast embeddings
3
  """
4
 
5
- import os
6
  from pathlib import Path
7
 
 
8
  def setup_google_api_key():
9
  """Interactive setup for Google API key"""
10
-
11
  print("="*70)
12
  print("Fast Embeddings Setup - Google Gemini API")
13
  print("="*70)
14
-
15
  print("\nWhy Google Gemini?")
16
  print(" - 100x faster than local Ollama (2 mins vs 30+ mins)")
17
  print(" - FREE for standard usage")
18
  print(" - High quality embeddings")
19
  print(" - Automatic fallback to Ollama if unavailable")
20
-
21
  print("\n" + "="*70)
22
  print("Step 1: Get Your Free API Key")
23
  print("="*70)
@@ -26,28 +26,28 @@ def setup_google_api_key():
26
  print("\n2. Sign in with Google account")
27
  print("3. Click 'Create API Key'")
28
  print("4. Copy the key (starts with 'AIza...')")
29
-
30
  input("\nPress ENTER when you have your API key ready...")
31
-
32
  api_key = input("\nPaste your Google API key here: ").strip()
33
-
34
  if not api_key:
35
  print("\nNo API key provided. Using local Ollama instead.")
36
  return False
37
-
38
  if not api_key.startswith("AIza"):
39
  print("\nWarning: Key doesn't start with 'AIza'. Are you sure this is correct?")
40
  confirm = input("Continue anyway? (y/n): ").strip().lower()
41
  if confirm != 'y':
42
  return False
43
-
44
  # Update .env file
45
  env_path = Path(".env")
46
-
47
  if env_path.exists():
48
- with open(env_path, 'r') as f:
49
  lines = f.readlines()
50
-
51
  # Update or add GOOGLE_API_KEY
52
  updated = False
53
  for i, line in enumerate(lines):
@@ -55,17 +55,17 @@ def setup_google_api_key():
55
  lines[i] = f'GOOGLE_API_KEY={api_key}\n'
56
  updated = True
57
  break
58
-
59
  if not updated:
60
  lines.insert(0, f'GOOGLE_API_KEY={api_key}\n')
61
-
62
  with open(env_path, 'w') as f:
63
  f.writelines(lines)
64
  else:
65
  # Create new .env file
66
  with open(env_path, 'w') as f:
67
  f.write(f'GOOGLE_API_KEY={api_key}\n')
68
-
69
  print("\nAPI key saved to .env file!")
70
  print("\n" + "="*70)
71
  print("Step 2: Build Vector Store")
@@ -74,7 +74,7 @@ def setup_google_api_key():
74
  print(" python src/pdf_processor.py")
75
  print("\nChoose option 1 (Google Gemini) when prompted.")
76
  print("\n" + "="*70)
77
-
78
  return True
79
 
80
 
 
2
  Quick script to help set up Google API key for fast embeddings
3
  """
4
 
 
5
  from pathlib import Path
6
 
7
+
8
  def setup_google_api_key():
9
  """Interactive setup for Google API key"""
10
+
11
  print("="*70)
12
  print("Fast Embeddings Setup - Google Gemini API")
13
  print("="*70)
14
+
15
  print("\nWhy Google Gemini?")
16
  print(" - 100x faster than local Ollama (2 mins vs 30+ mins)")
17
  print(" - FREE for standard usage")
18
  print(" - High quality embeddings")
19
  print(" - Automatic fallback to Ollama if unavailable")
20
+
21
  print("\n" + "="*70)
22
  print("Step 1: Get Your Free API Key")
23
  print("="*70)
 
26
  print("\n2. Sign in with Google account")
27
  print("3. Click 'Create API Key'")
28
  print("4. Copy the key (starts with 'AIza...')")
29
+
30
  input("\nPress ENTER when you have your API key ready...")
31
+
32
  api_key = input("\nPaste your Google API key here: ").strip()
33
+
34
  if not api_key:
35
  print("\nNo API key provided. Using local Ollama instead.")
36
  return False
37
+
38
  if not api_key.startswith("AIza"):
39
  print("\nWarning: Key doesn't start with 'AIza'. Are you sure this is correct?")
40
  confirm = input("Continue anyway? (y/n): ").strip().lower()
41
  if confirm != 'y':
42
  return False
43
+
44
  # Update .env file
45
  env_path = Path(".env")
46
+
47
  if env_path.exists():
48
+ with open(env_path) as f:
49
  lines = f.readlines()
50
+
51
  # Update or add GOOGLE_API_KEY
52
  updated = False
53
  for i, line in enumerate(lines):
 
55
  lines[i] = f'GOOGLE_API_KEY={api_key}\n'
56
  updated = True
57
  break
58
+
59
  if not updated:
60
  lines.insert(0, f'GOOGLE_API_KEY={api_key}\n')
61
+
62
  with open(env_path, 'w') as f:
63
  f.writelines(lines)
64
  else:
65
  # Create new .env file
66
  with open(env_path, 'w') as f:
67
  f.write(f'GOOGLE_API_KEY={api_key}\n')
68
+
69
  print("\nAPI key saved to .env file!")
70
  print("\n" + "="*70)
71
  print("Step 2: Build Vector Store")
 
74
  print(" python src/pdf_processor.py")
75
  print("\nChoose option 1 (Google Gemini) when prompted.")
76
  print("\n" + "="*70)
77
+
78
  return True
79
 
80
 
scripts/test_chat_demo.py CHANGED
@@ -4,7 +4,6 @@ Quick demo script to test the chatbot with pre-defined inputs
4
 
5
  import subprocess
6
  import sys
7
- from pathlib import Path
8
 
9
  # Test inputs
10
  test_cases = [
@@ -36,16 +35,16 @@ try:
36
  encoding='utf-8',
37
  errors='replace'
38
  )
39
-
40
  print("STDOUT:")
41
  print(result.stdout)
42
-
43
  if result.stderr:
44
  print("\nSTDERR:")
45
  print(result.stderr)
46
-
47
  print(f"\nExit code: {result.returncode}")
48
-
49
  except subprocess.TimeoutExpired:
50
  print("⚠️ Test timed out after 120 seconds")
51
  except Exception as e:
 
4
 
5
  import subprocess
6
  import sys
 
7
 
8
  # Test inputs
9
  test_cases = [
 
35
  encoding='utf-8',
36
  errors='replace'
37
  )
38
+
39
  print("STDOUT:")
40
  print(result.stdout)
41
+
42
  if result.stderr:
43
  print("\nSTDERR:")
44
  print(result.stderr)
45
+
46
  print(f"\nExit code: {result.returncode}")
47
+
48
  except subprocess.TimeoutExpired:
49
  print("⚠️ Test timed out after 120 seconds")
50
  except Exception as e:
scripts/test_extraction.py CHANGED
@@ -4,6 +4,7 @@ Quick test to verify biomarker extraction is working
4
 
5
  import sys
6
  from pathlib import Path
 
7
  sys.path.insert(0, str(Path(__file__).parent.parent))
8
 
9
  from scripts.chat import extract_biomarkers, predict_disease_llm
@@ -22,25 +23,25 @@ print("="*70)
22
  for i, test_input in enumerate(test_inputs, 1):
23
  print(f"\n[Test {i}] Input: '{test_input}'")
24
  print("-"*70)
25
-
26
  biomarkers, context = extract_biomarkers(test_input)
27
-
28
  if biomarkers:
29
  print(f"✅ SUCCESS: Found {len(biomarkers)} biomarkers")
30
  for name, value in biomarkers.items():
31
  print(f" - {name}: {value}")
32
-
33
  if context:
34
  print(f" Context: {context}")
35
-
36
  # Test prediction
37
  print("\n Testing prediction...")
38
  prediction = predict_disease_llm(biomarkers, context)
39
  print(f" Predicted: {prediction['disease']} ({prediction['confidence']:.0%})")
40
-
41
  else:
42
- print(f"❌ FAILED: No biomarkers extracted")
43
-
44
  print()
45
 
46
  print("="*70)
 
4
 
5
  import sys
6
  from pathlib import Path
7
+
8
  sys.path.insert(0, str(Path(__file__).parent.parent))
9
 
10
  from scripts.chat import extract_biomarkers, predict_disease_llm
 
23
  for i, test_input in enumerate(test_inputs, 1):
24
  print(f"\n[Test {i}] Input: '{test_input}'")
25
  print("-"*70)
26
+
27
  biomarkers, context = extract_biomarkers(test_input)
28
+
29
  if biomarkers:
30
  print(f"✅ SUCCESS: Found {len(biomarkers)} biomarkers")
31
  for name, value in biomarkers.items():
32
  print(f" - {name}: {value}")
33
+
34
  if context:
35
  print(f" Context: {context}")
36
+
37
  # Test prediction
38
  print("\n Testing prediction...")
39
  prediction = predict_disease_llm(biomarkers, context)
40
  print(f" Predicted: {prediction['disease']} ({prediction['confidence']:.0%})")
41
+
42
  else:
43
+ print("❌ FAILED: No biomarkers extracted")
44
+
45
  print()
46
 
47
  print("="*70)
src/agents/biomarker_analyzer.py CHANGED
@@ -3,19 +3,19 @@ MediGuard AI RAG-Helper
3
  Biomarker Analyzer Agent - Validates biomarker values and flags anomalies
4
  """
5
 
6
- from typing import Dict, List
7
- from src.state import GuildState, AgentOutput, BiomarkerFlag
8
  from src.biomarker_validator import BiomarkerValidator
9
  from src.llm_config import llm_config
 
10
 
11
 
12
  class BiomarkerAnalyzerAgent:
13
  """Agent that validates biomarker values and generates comprehensive analysis"""
14
-
15
  def __init__(self):
16
  self.validator = BiomarkerValidator()
17
  self.llm = llm_config.analyzer
18
-
19
  def analyze(self, state: GuildState) -> GuildState:
20
  """
21
  Main agent function to analyze biomarkers.
@@ -29,12 +29,12 @@ class BiomarkerAnalyzerAgent:
29
  print("\n" + "="*70)
30
  print("EXECUTING: Biomarker Analyzer Agent")
31
  print("="*70)
32
-
33
  biomarkers = state['patient_biomarkers']
34
  patient_context = state.get('patient_context', {})
35
  gender = patient_context.get('gender') # None if not provided — uses non-gender-specific ranges
36
  predicted_disease = state['model_prediction']['disease']
37
-
38
  # Validate all biomarkers
39
  print(f"\nValidating {len(biomarkers)} biomarkers...")
40
  flags, alerts = self.validator.validate_all(
@@ -42,13 +42,13 @@ class BiomarkerAnalyzerAgent:
42
  gender=gender,
43
  threshold_pct=state['sop'].biomarker_analyzer_threshold
44
  )
45
-
46
  # Get disease-relevant biomarkers
47
  relevant_biomarkers = self.validator.get_disease_relevant_biomarkers(predicted_disease)
48
-
49
  # Generate summary using LLM
50
  summary = self._generate_summary(biomarkers, flags, alerts, relevant_biomarkers, predicted_disease)
51
-
52
  findings = {
53
  "biomarker_flags": [flag.model_dump() for flag in flags],
54
  "safety_alerts": [alert.model_dump() for alert in alerts],
@@ -62,35 +62,35 @@ class BiomarkerAnalyzerAgent:
62
  agent_name="Biomarker Analyzer",
63
  findings=findings
64
  )
65
-
66
  # Update state
67
  print("\nAnalysis complete:")
68
  print(f" - {len(flags)} biomarkers validated")
69
  print(f" - {len([f for f in flags if f.status != 'NORMAL'])} out-of-range values")
70
  print(f" - {len(alerts)} safety alerts generated")
71
  print(f" - {len(relevant_biomarkers)} disease-relevant biomarkers identified")
72
-
73
  return {
74
  'agent_outputs': [output],
75
  'biomarker_flags': flags,
76
  'safety_alerts': alerts,
77
  'biomarker_analysis': findings
78
  }
79
-
80
  def _generate_summary(
81
  self,
82
- biomarkers: Dict[str, float],
83
- flags: List[BiomarkerFlag],
84
- alerts: List,
85
- relevant_biomarkers: List[str],
86
  disease: str
87
  ) -> str:
88
  """Generate a concise summary of biomarker findings"""
89
-
90
  # Count anomalies
91
  critical = [f for f in flags if 'CRITICAL' in f.status]
92
  high_low = [f for f in flags if f.status in ['HIGH', 'LOW']]
93
-
94
  prompt = f"""You are a medical data analyst. Provide a brief, clinical summary of these biomarker results.
95
 
96
  **Patient Context:**
@@ -115,24 +115,24 @@ Keep it concise and clinical."""
115
  except Exception as e:
116
  print(f"Warning: LLM summary generation failed: {e}")
117
  return f"Biomarker analysis complete. {len(critical)} critical values, {len(high_low)} out-of-range values detected."
118
-
119
  def _format_key_findings(self, critical, high_low, relevant):
120
  """Format findings for LLM prompt"""
121
  findings = []
122
-
123
  if critical:
124
  findings.append("CRITICAL VALUES:")
125
  for f in critical[:3]: # Top 3
126
  findings.append(f" - {f.name}: {f.value} {f.unit} ({f.status})")
127
-
128
  if high_low:
129
  findings.append("\nOUT-OF-RANGE VALUES:")
130
  for f in high_low[:5]: # Top 5
131
  findings.append(f" - {f.name}: {f.value} {f.unit} ({f.status})")
132
-
133
  if relevant:
134
  findings.append(f"\nDISEASE-RELEVANT BIOMARKERS: {', '.join(relevant[:5])}")
135
-
136
  return "\n".join(findings) if findings else "All biomarkers within normal range."
137
 
138
 
 
3
  Biomarker Analyzer Agent - Validates biomarker values and flags anomalies
4
  """
5
 
6
+
 
7
  from src.biomarker_validator import BiomarkerValidator
8
  from src.llm_config import llm_config
9
+ from src.state import AgentOutput, BiomarkerFlag, GuildState
10
 
11
 
12
  class BiomarkerAnalyzerAgent:
13
  """Agent that validates biomarker values and generates comprehensive analysis"""
14
+
15
  def __init__(self):
16
  self.validator = BiomarkerValidator()
17
  self.llm = llm_config.analyzer
18
+
19
  def analyze(self, state: GuildState) -> GuildState:
20
  """
21
  Main agent function to analyze biomarkers.
 
29
  print("\n" + "="*70)
30
  print("EXECUTING: Biomarker Analyzer Agent")
31
  print("="*70)
32
+
33
  biomarkers = state['patient_biomarkers']
34
  patient_context = state.get('patient_context', {})
35
  gender = patient_context.get('gender') # None if not provided — uses non-gender-specific ranges
36
  predicted_disease = state['model_prediction']['disease']
37
+
38
  # Validate all biomarkers
39
  print(f"\nValidating {len(biomarkers)} biomarkers...")
40
  flags, alerts = self.validator.validate_all(
 
42
  gender=gender,
43
  threshold_pct=state['sop'].biomarker_analyzer_threshold
44
  )
45
+
46
  # Get disease-relevant biomarkers
47
  relevant_biomarkers = self.validator.get_disease_relevant_biomarkers(predicted_disease)
48
+
49
  # Generate summary using LLM
50
  summary = self._generate_summary(biomarkers, flags, alerts, relevant_biomarkers, predicted_disease)
51
+
52
  findings = {
53
  "biomarker_flags": [flag.model_dump() for flag in flags],
54
  "safety_alerts": [alert.model_dump() for alert in alerts],
 
62
  agent_name="Biomarker Analyzer",
63
  findings=findings
64
  )
65
+
66
  # Update state
67
  print("\nAnalysis complete:")
68
  print(f" - {len(flags)} biomarkers validated")
69
  print(f" - {len([f for f in flags if f.status != 'NORMAL'])} out-of-range values")
70
  print(f" - {len(alerts)} safety alerts generated")
71
  print(f" - {len(relevant_biomarkers)} disease-relevant biomarkers identified")
72
+
73
  return {
74
  'agent_outputs': [output],
75
  'biomarker_flags': flags,
76
  'safety_alerts': alerts,
77
  'biomarker_analysis': findings
78
  }
79
+
80
  def _generate_summary(
81
  self,
82
+ biomarkers: dict[str, float],
83
+ flags: list[BiomarkerFlag],
84
+ alerts: list,
85
+ relevant_biomarkers: list[str],
86
  disease: str
87
  ) -> str:
88
  """Generate a concise summary of biomarker findings"""
89
+
90
  # Count anomalies
91
  critical = [f for f in flags if 'CRITICAL' in f.status]
92
  high_low = [f for f in flags if f.status in ['HIGH', 'LOW']]
93
+
94
  prompt = f"""You are a medical data analyst. Provide a brief, clinical summary of these biomarker results.
95
 
96
  **Patient Context:**
 
115
  except Exception as e:
116
  print(f"Warning: LLM summary generation failed: {e}")
117
  return f"Biomarker analysis complete. {len(critical)} critical values, {len(high_low)} out-of-range values detected."
118
+
119
  def _format_key_findings(self, critical, high_low, relevant):
120
  """Format findings for LLM prompt"""
121
  findings = []
122
+
123
  if critical:
124
  findings.append("CRITICAL VALUES:")
125
  for f in critical[:3]: # Top 3
126
  findings.append(f" - {f.name}: {f.value} {f.unit} ({f.status})")
127
+
128
  if high_low:
129
  findings.append("\nOUT-OF-RANGE VALUES:")
130
  for f in high_low[:5]: # Top 5
131
  findings.append(f" - {f.name}: {f.value} {f.unit} ({f.status})")
132
+
133
  if relevant:
134
  findings.append(f"\nDISEASE-RELEVANT BIOMARKERS: {', '.join(relevant[:5])}")
135
+
136
  return "\n".join(findings) if findings else "All biomarkers within normal range."
137
 
138
 
src/agents/biomarker_linker.py CHANGED
@@ -3,15 +3,15 @@ MediGuard AI RAG-Helper
3
  Biomarker-Disease Linker Agent - Connects biomarker values to predicted disease
4
  """
5
 
6
- from typing import Dict, List
7
- from src.state import GuildState, AgentOutput, KeyDriver
8
  from src.llm_config import llm_config
9
- from langchain_core.prompts import ChatPromptTemplate
10
 
11
 
12
  class BiomarkerDiseaseLinkerAgent:
13
  """Agent that links specific biomarker values to the predicted disease"""
14
-
15
  def __init__(self, retriever):
16
  """
17
  Initialize with a retriever for biomarker-disease connections.
@@ -21,7 +21,7 @@ class BiomarkerDiseaseLinkerAgent:
21
  """
22
  self.retriever = retriever
23
  self.llm = llm_config.explainer
24
-
25
  def link(self, state: GuildState) -> GuildState:
26
  """
27
  Link biomarkers to disease prediction.
@@ -35,14 +35,14 @@ class BiomarkerDiseaseLinkerAgent:
35
  print("\n" + "="*70)
36
  print("EXECUTING: Biomarker-Disease Linker Agent (RAG)")
37
  print("="*70)
38
-
39
  model_prediction = state['model_prediction']
40
  disease = model_prediction['disease']
41
  biomarkers = state['patient_biomarkers']
42
-
43
  # Get biomarker analysis from previous agent
44
  biomarker_analysis = state.get('biomarker_analysis') or {}
45
-
46
  # Identify key drivers
47
  print(f"\nIdentifying key drivers for {disease}...")
48
  key_drivers, citations_missing = self._identify_key_drivers(
@@ -51,9 +51,9 @@ class BiomarkerDiseaseLinkerAgent:
51
  biomarker_analysis,
52
  state
53
  )
54
-
55
  print(f"Identified {len(key_drivers)} key biomarker drivers")
56
-
57
  # Create agent output
58
  output = AgentOutput(
59
  agent_name="Biomarker-Disease Linker",
@@ -65,45 +65,45 @@ class BiomarkerDiseaseLinkerAgent:
65
  "citations_missing": citations_missing
66
  }
67
  )
68
-
69
  # Update state
70
  print("\nBiomarker-disease linking complete")
71
-
72
  return {'agent_outputs': [output]}
73
-
74
  def _identify_key_drivers(
75
  self,
76
  disease: str,
77
- biomarkers: Dict[str, float],
78
  analysis: dict,
79
  state: GuildState
80
- ) -> tuple[List[KeyDriver], bool]:
81
  """Identify which biomarkers are driving the disease prediction"""
82
-
83
  # Get out-of-range biomarkers from analysis
84
  flags = analysis.get('biomarker_flags', [])
85
  abnormal_biomarkers = [
86
- f for f in flags
87
  if f['status'] != 'NORMAL'
88
  ]
89
-
90
  # Get disease-relevant biomarkers
91
  relevant = analysis.get('relevant_biomarkers', [])
92
-
93
  # Focus on biomarkers that are both abnormal AND disease-relevant
94
  key_biomarkers = [
95
  f for f in abnormal_biomarkers
96
  if f['name'] in relevant
97
  ]
98
-
99
  # If no key biomarkers found, use top abnormal ones
100
  if not key_biomarkers:
101
  key_biomarkers = abnormal_biomarkers[:5]
102
-
103
  print(f" Analyzing {len(key_biomarkers)} key biomarkers...")
104
-
105
  # Generate key drivers with evidence
106
- key_drivers: List[KeyDriver] = []
107
  citations_missing = False
108
  for biomarker_flag in key_biomarkers[:5]: # Top 5
109
  driver, driver_missing = self._create_key_driver(
@@ -115,7 +115,7 @@ class BiomarkerDiseaseLinkerAgent:
115
  citations_missing = citations_missing or driver_missing
116
 
117
  return key_drivers, citations_missing
118
-
119
  def _create_key_driver(
120
  self,
121
  biomarker_flag: dict,
@@ -123,15 +123,15 @@ class BiomarkerDiseaseLinkerAgent:
123
  state: GuildState
124
  ) -> tuple[KeyDriver, bool]:
125
  """Create a KeyDriver object with evidence from RAG"""
126
-
127
  name = biomarker_flag['name']
128
  value = biomarker_flag['value']
129
  unit = biomarker_flag['unit']
130
  status = biomarker_flag['status']
131
-
132
  # Retrieve evidence linking this biomarker to the disease
133
  query = f"How does {name} relate to {disease}? What does {status} {name} indicate?"
134
-
135
  citations_missing = False
136
  try:
137
  docs = self.retriever.invoke(query)
@@ -147,12 +147,12 @@ class BiomarkerDiseaseLinkerAgent:
147
  evidence_text = f"{status} {name} may be related to {disease}."
148
  contribution = "Unknown"
149
  citations_missing = True
150
-
151
  # Generate explanation using LLM
152
  explanation = self._generate_explanation(
153
  name, value, unit, status, disease, evidence_text
154
  )
155
-
156
  driver = KeyDriver(
157
  biomarker=name,
158
  value=value,
@@ -162,12 +162,12 @@ class BiomarkerDiseaseLinkerAgent:
162
  )
163
 
164
  return driver, citations_missing
165
-
166
  def _extract_evidence(self, docs: list, biomarker: str, disease: str) -> str:
167
  """Extract relevant evidence from retrieved documents"""
168
  if not docs:
169
  return f"Limited evidence available for {biomarker} in {disease}."
170
-
171
  # Combine relevant passages
172
  evidence = []
173
  for doc in docs[:2]: # Top 2 docs
@@ -175,17 +175,17 @@ class BiomarkerDiseaseLinkerAgent:
175
  # Extract sentences mentioning the biomarker
176
  sentences = content.split('.')
177
  relevant_sentences = [
178
- s.strip() for s in sentences
179
  if biomarker.lower() in s.lower() or disease.lower() in s.lower()
180
  ]
181
  evidence.extend(relevant_sentences[:2])
182
-
183
  return ". ".join(evidence[:3]) + "." if evidence else content[:300]
184
-
185
  def _estimate_contribution(self, biomarker_flag: dict, doc_count: int) -> str:
186
  """Estimate the contribution percentage (simplified)"""
187
  status = biomarker_flag['status']
188
-
189
  # Simple heuristic based on severity
190
  if 'CRITICAL' in status:
191
  base = 40
@@ -193,13 +193,13 @@ class BiomarkerDiseaseLinkerAgent:
193
  base = 25
194
  else:
195
  base = 10
196
-
197
  # Adjust based on evidence strength
198
  evidence_boost = min(doc_count * 2, 15)
199
-
200
  total = min(base + evidence_boost, 60)
201
  return f"{total}%"
202
-
203
  def _generate_explanation(
204
  self,
205
  biomarker: str,
@@ -210,7 +210,7 @@ class BiomarkerDiseaseLinkerAgent:
210
  evidence: str
211
  ) -> str:
212
  """Generate patient-friendly explanation"""
213
-
214
  prompt = f"""Explain in 1-2 sentences how this biomarker result relates to {disease}:
215
 
216
  Biomarker: {biomarker}
@@ -220,11 +220,11 @@ Status: {status}
220
  Medical Evidence: {evidence}
221
 
222
  Write in patient-friendly language, explaining what this means for the diagnosis."""
223
-
224
  try:
225
  response = self.llm.invoke(prompt)
226
  return response.content.strip()
227
- except Exception as e:
228
  return f"{biomarker} at {value} {unit} is {status}, which may be associated with {disease}."
229
 
230
 
 
3
  Biomarker-Disease Linker Agent - Connects biomarker values to predicted disease
4
  """
5
 
6
+
7
+
8
  from src.llm_config import llm_config
9
+ from src.state import AgentOutput, GuildState, KeyDriver
10
 
11
 
12
  class BiomarkerDiseaseLinkerAgent:
13
  """Agent that links specific biomarker values to the predicted disease"""
14
+
15
  def __init__(self, retriever):
16
  """
17
  Initialize with a retriever for biomarker-disease connections.
 
21
  """
22
  self.retriever = retriever
23
  self.llm = llm_config.explainer
24
+
25
  def link(self, state: GuildState) -> GuildState:
26
  """
27
  Link biomarkers to disease prediction.
 
35
  print("\n" + "="*70)
36
  print("EXECUTING: Biomarker-Disease Linker Agent (RAG)")
37
  print("="*70)
38
+
39
  model_prediction = state['model_prediction']
40
  disease = model_prediction['disease']
41
  biomarkers = state['patient_biomarkers']
42
+
43
  # Get biomarker analysis from previous agent
44
  biomarker_analysis = state.get('biomarker_analysis') or {}
45
+
46
  # Identify key drivers
47
  print(f"\nIdentifying key drivers for {disease}...")
48
  key_drivers, citations_missing = self._identify_key_drivers(
 
51
  biomarker_analysis,
52
  state
53
  )
54
+
55
  print(f"Identified {len(key_drivers)} key biomarker drivers")
56
+
57
  # Create agent output
58
  output = AgentOutput(
59
  agent_name="Biomarker-Disease Linker",
 
65
  "citations_missing": citations_missing
66
  }
67
  )
68
+
69
  # Update state
70
  print("\nBiomarker-disease linking complete")
71
+
72
  return {'agent_outputs': [output]}
73
+
74
  def _identify_key_drivers(
75
  self,
76
  disease: str,
77
+ biomarkers: dict[str, float],
78
  analysis: dict,
79
  state: GuildState
80
+ ) -> tuple[list[KeyDriver], bool]:
81
  """Identify which biomarkers are driving the disease prediction"""
82
+
83
  # Get out-of-range biomarkers from analysis
84
  flags = analysis.get('biomarker_flags', [])
85
  abnormal_biomarkers = [
86
+ f for f in flags
87
  if f['status'] != 'NORMAL'
88
  ]
89
+
90
  # Get disease-relevant biomarkers
91
  relevant = analysis.get('relevant_biomarkers', [])
92
+
93
  # Focus on biomarkers that are both abnormal AND disease-relevant
94
  key_biomarkers = [
95
  f for f in abnormal_biomarkers
96
  if f['name'] in relevant
97
  ]
98
+
99
  # If no key biomarkers found, use top abnormal ones
100
  if not key_biomarkers:
101
  key_biomarkers = abnormal_biomarkers[:5]
102
+
103
  print(f" Analyzing {len(key_biomarkers)} key biomarkers...")
104
+
105
  # Generate key drivers with evidence
106
+ key_drivers: list[KeyDriver] = []
107
  citations_missing = False
108
  for biomarker_flag in key_biomarkers[:5]: # Top 5
109
  driver, driver_missing = self._create_key_driver(
 
115
  citations_missing = citations_missing or driver_missing
116
 
117
  return key_drivers, citations_missing
118
+
119
  def _create_key_driver(
120
  self,
121
  biomarker_flag: dict,
 
123
  state: GuildState
124
  ) -> tuple[KeyDriver, bool]:
125
  """Create a KeyDriver object with evidence from RAG"""
126
+
127
  name = biomarker_flag['name']
128
  value = biomarker_flag['value']
129
  unit = biomarker_flag['unit']
130
  status = biomarker_flag['status']
131
+
132
  # Retrieve evidence linking this biomarker to the disease
133
  query = f"How does {name} relate to {disease}? What does {status} {name} indicate?"
134
+
135
  citations_missing = False
136
  try:
137
  docs = self.retriever.invoke(query)
 
147
  evidence_text = f"{status} {name} may be related to {disease}."
148
  contribution = "Unknown"
149
  citations_missing = True
150
+
151
  # Generate explanation using LLM
152
  explanation = self._generate_explanation(
153
  name, value, unit, status, disease, evidence_text
154
  )
155
+
156
  driver = KeyDriver(
157
  biomarker=name,
158
  value=value,
 
162
  )
163
 
164
  return driver, citations_missing
165
+
166
  def _extract_evidence(self, docs: list, biomarker: str, disease: str) -> str:
167
  """Extract relevant evidence from retrieved documents"""
168
  if not docs:
169
  return f"Limited evidence available for {biomarker} in {disease}."
170
+
171
  # Combine relevant passages
172
  evidence = []
173
  for doc in docs[:2]: # Top 2 docs
 
175
  # Extract sentences mentioning the biomarker
176
  sentences = content.split('.')
177
  relevant_sentences = [
178
+ s.strip() for s in sentences
179
  if biomarker.lower() in s.lower() or disease.lower() in s.lower()
180
  ]
181
  evidence.extend(relevant_sentences[:2])
182
+
183
  return ". ".join(evidence[:3]) + "." if evidence else content[:300]
184
+
185
  def _estimate_contribution(self, biomarker_flag: dict, doc_count: int) -> str:
186
  """Estimate the contribution percentage (simplified)"""
187
  status = biomarker_flag['status']
188
+
189
  # Simple heuristic based on severity
190
  if 'CRITICAL' in status:
191
  base = 40
 
193
  base = 25
194
  else:
195
  base = 10
196
+
197
  # Adjust based on evidence strength
198
  evidence_boost = min(doc_count * 2, 15)
199
+
200
  total = min(base + evidence_boost, 60)
201
  return f"{total}%"
202
+
203
  def _generate_explanation(
204
  self,
205
  biomarker: str,
 
210
  evidence: str
211
  ) -> str:
212
  """Generate patient-friendly explanation"""
213
+
214
  prompt = f"""Explain in 1-2 sentences how this biomarker result relates to {disease}:
215
 
216
  Biomarker: {biomarker}
 
220
  Medical Evidence: {evidence}
221
 
222
  Write in patient-friendly language, explaining what this means for the diagnosis."""
223
+
224
  try:
225
  response = self.llm.invoke(prompt)
226
  return response.content.strip()
227
+ except Exception:
228
  return f"{biomarker} at {value} {unit} is {status}, which may be associated with {disease}."
229
 
230
 
src/agents/clinical_guidelines.py CHANGED
@@ -4,15 +4,16 @@ Clinical Guidelines Agent - Retrieves evidence-based recommendations
4
  """
5
 
6
  from pathlib import Path
7
- from typing import List
8
- from src.state import GuildState, AgentOutput
9
- from src.llm_config import llm_config
10
  from langchain_core.prompts import ChatPromptTemplate
11
 
 
 
 
12
 
13
  class ClinicalGuidelinesAgent:
14
  """Agent that retrieves clinical guidelines and recommendations using RAG"""
15
-
16
  def __init__(self, retriever):
17
  """
18
  Initialize with a retriever for clinical guidelines.
@@ -22,7 +23,7 @@ class ClinicalGuidelinesAgent:
22
  """
23
  self.retriever = retriever
24
  self.llm = llm_config.explainer
25
-
26
  def recommend(self, state: GuildState) -> GuildState:
27
  """
28
  Retrieve clinical guidelines and generate recommendations.
@@ -36,25 +37,25 @@ class ClinicalGuidelinesAgent:
36
  print("\n" + "="*70)
37
  print("EXECUTING: Clinical Guidelines Agent (RAG)")
38
  print("="*70)
39
-
40
  model_prediction = state['model_prediction']
41
  disease = model_prediction['disease']
42
  confidence = model_prediction['confidence']
43
-
44
  # Get biomarker analysis
45
  biomarker_analysis = state.get('biomarker_analysis') or {}
46
  safety_alerts = biomarker_analysis.get('safety_alerts', [])
47
-
48
  # Retrieve guidelines
49
  print(f"\nRetrieving clinical guidelines for {disease}...")
50
-
51
  query = f"""What are the clinical practice guidelines for managing {disease}?
52
  Include lifestyle modifications, monitoring recommendations, and when to seek medical care."""
53
-
54
  docs = self.retriever.invoke(query)
55
-
56
  print(f"Retrieved {len(docs)} guideline documents")
57
-
58
  # Generate recommendations
59
  if state['sop'].require_pdf_citations and not docs:
60
  recommendations = {
@@ -73,7 +74,7 @@ class ClinicalGuidelinesAgent:
73
  confidence,
74
  state
75
  )
76
-
77
  # Create agent output
78
  output = AgentOutput(
79
  agent_name="Clinical Guidelines",
@@ -87,15 +88,15 @@ class ClinicalGuidelinesAgent:
87
  "citations_missing": state['sop'].require_pdf_citations and not docs
88
  }
89
  )
90
-
91
  # Update state
92
  print("\nRecommendations generated")
93
  print(f" - Immediate actions: {len(recommendations['immediate_actions'])}")
94
  print(f" - Lifestyle changes: {len(recommendations['lifestyle_changes'])}")
95
  print(f" - Monitoring recommendations: {len(recommendations['monitoring'])}")
96
-
97
  return {'agent_outputs': [output]}
98
-
99
  def _generate_recommendations(
100
  self,
101
  disease: str,
@@ -105,20 +106,20 @@ class ClinicalGuidelinesAgent:
105
  state: GuildState
106
  ) -> dict:
107
  """Generate structured recommendations using LLM and guidelines"""
108
-
109
  # Format retrieved guidelines
110
  guidelines_context = "\n\n---\n\n".join([
111
  f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
112
  for doc in docs
113
  ])
114
-
115
  # Build safety context
116
  safety_context = ""
117
  if safety_alerts:
118
  safety_context = "\n**CRITICAL SAFETY ALERTS:**\n"
119
  for alert in safety_alerts[:3]:
120
  safety_context += f"- {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}\n"
121
-
122
  prompt = ChatPromptTemplate.from_messages([
123
  ("system", """You are a clinical decision support system providing evidence-based recommendations.
124
  Based on clinical practice guidelines, provide actionable recommendations for patient self-assessment.
@@ -139,9 +140,9 @@ class ClinicalGuidelinesAgent:
139
 
140
  Please provide structured recommendations for patient self-assessment.""")
141
  ])
142
-
143
  chain = prompt | self.llm
144
-
145
  try:
146
  response = chain.invoke({
147
  "disease": disease,
@@ -149,18 +150,18 @@ class ClinicalGuidelinesAgent:
149
  "safety_context": safety_context,
150
  "guidelines": guidelines_context
151
  })
152
-
153
  recommendations = self._parse_recommendations(response.content)
154
-
155
  except Exception as e:
156
  print(f"Warning: LLM recommendation generation failed: {e}")
157
  recommendations = self._get_default_recommendations(disease, safety_alerts)
158
-
159
  # Add citations
160
  recommendations['citations'] = self._extract_citations(docs)
161
-
162
  return recommendations
163
-
164
  def _parse_recommendations(self, content: str) -> dict:
165
  """Parse LLM response into structured recommendations"""
166
  recommendations = {
@@ -168,14 +169,14 @@ class ClinicalGuidelinesAgent:
168
  "lifestyle_changes": [],
169
  "monitoring": []
170
  }
171
-
172
  current_section = None
173
  lines = content.split('\n')
174
-
175
  for line in lines:
176
  line_stripped = line.strip()
177
  line_upper = line_stripped.upper()
178
-
179
  # Detect section headers
180
  if 'IMMEDIATE' in line_upper or 'URGENT' in line_upper:
181
  current_section = 'immediate_actions'
@@ -189,16 +190,16 @@ class ClinicalGuidelinesAgent:
189
  cleaned = line_stripped.lstrip('•-*0123456789. ')
190
  if cleaned and len(cleaned) > 10: # Minimum length filter
191
  recommendations[current_section].append(cleaned)
192
-
193
  # If parsing failed, create default structure
194
  if not any(recommendations.values()):
195
  sentences = content.split('.')
196
  recommendations['immediate_actions'] = [s.strip() for s in sentences[:2] if s.strip()]
197
  recommendations['lifestyle_changes'] = [s.strip() for s in sentences[2:4] if s.strip()]
198
  recommendations['monitoring'] = [s.strip() for s in sentences[4:6] if s.strip()]
199
-
200
  return recommendations
201
-
202
  def _get_default_recommendations(self, disease: str, safety_alerts: list) -> dict:
203
  """Provide default recommendations if LLM fails"""
204
  recommendations = {
@@ -206,7 +207,7 @@ class ClinicalGuidelinesAgent:
206
  "lifestyle_changes": [],
207
  "monitoring": []
208
  }
209
-
210
  # Add safety-based immediate actions
211
  if safety_alerts:
212
  recommendations['immediate_actions'].append(
@@ -219,36 +220,36 @@ class ClinicalGuidelinesAgent:
219
  recommendations['immediate_actions'].append(
220
  f"Schedule appointment with healthcare provider to discuss {disease} findings"
221
  )
222
-
223
  # Generic lifestyle changes
224
  recommendations['lifestyle_changes'].extend([
225
  "Follow a balanced, nutrient-rich diet as recommended by healthcare provider",
226
  "Maintain regular physical activity appropriate for your health status",
227
  "Track symptoms and biomarker trends over time"
228
  ])
229
-
230
  # Generic monitoring
231
  recommendations['monitoring'].extend([
232
  f"Regular monitoring of {disease}-related biomarkers as advised by physician",
233
  "Keep a health journal tracking symptoms, diet, and activities",
234
  "Schedule follow-up appointments as recommended"
235
  ])
236
-
237
  return recommendations
238
-
239
- def _extract_citations(self, docs: list) -> List[str]:
240
  """Extract citations from retrieved guideline documents"""
241
  citations = []
242
-
243
  for doc in docs:
244
  source = doc.metadata.get('source', 'Unknown')
245
-
246
  # Clean up source path
247
  if '\\' in source or '/' in source:
248
  source = Path(source).name
249
-
250
  citations.append(source)
251
-
252
  return list(set(citations)) # Remove duplicates
253
 
254
 
 
4
  """
5
 
6
  from pathlib import Path
7
+
 
 
8
  from langchain_core.prompts import ChatPromptTemplate
9
 
10
+ from src.llm_config import llm_config
11
+ from src.state import AgentOutput, GuildState
12
+
13
 
14
  class ClinicalGuidelinesAgent:
15
  """Agent that retrieves clinical guidelines and recommendations using RAG"""
16
+
17
  def __init__(self, retriever):
18
  """
19
  Initialize with a retriever for clinical guidelines.
 
23
  """
24
  self.retriever = retriever
25
  self.llm = llm_config.explainer
26
+
27
  def recommend(self, state: GuildState) -> GuildState:
28
  """
29
  Retrieve clinical guidelines and generate recommendations.
 
37
  print("\n" + "="*70)
38
  print("EXECUTING: Clinical Guidelines Agent (RAG)")
39
  print("="*70)
40
+
41
  model_prediction = state['model_prediction']
42
  disease = model_prediction['disease']
43
  confidence = model_prediction['confidence']
44
+
45
  # Get biomarker analysis
46
  biomarker_analysis = state.get('biomarker_analysis') or {}
47
  safety_alerts = biomarker_analysis.get('safety_alerts', [])
48
+
49
  # Retrieve guidelines
50
  print(f"\nRetrieving clinical guidelines for {disease}...")
51
+
52
  query = f"""What are the clinical practice guidelines for managing {disease}?
53
  Include lifestyle modifications, monitoring recommendations, and when to seek medical care."""
54
+
55
  docs = self.retriever.invoke(query)
56
+
57
  print(f"Retrieved {len(docs)} guideline documents")
58
+
59
  # Generate recommendations
60
  if state['sop'].require_pdf_citations and not docs:
61
  recommendations = {
 
74
  confidence,
75
  state
76
  )
77
+
78
  # Create agent output
79
  output = AgentOutput(
80
  agent_name="Clinical Guidelines",
 
88
  "citations_missing": state['sop'].require_pdf_citations and not docs
89
  }
90
  )
91
+
92
  # Update state
93
  print("\nRecommendations generated")
94
  print(f" - Immediate actions: {len(recommendations['immediate_actions'])}")
95
  print(f" - Lifestyle changes: {len(recommendations['lifestyle_changes'])}")
96
  print(f" - Monitoring recommendations: {len(recommendations['monitoring'])}")
97
+
98
  return {'agent_outputs': [output]}
99
+
100
  def _generate_recommendations(
101
  self,
102
  disease: str,
 
106
  state: GuildState
107
  ) -> dict:
108
  """Generate structured recommendations using LLM and guidelines"""
109
+
110
  # Format retrieved guidelines
111
  guidelines_context = "\n\n---\n\n".join([
112
  f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
113
  for doc in docs
114
  ])
115
+
116
  # Build safety context
117
  safety_context = ""
118
  if safety_alerts:
119
  safety_context = "\n**CRITICAL SAFETY ALERTS:**\n"
120
  for alert in safety_alerts[:3]:
121
  safety_context += f"- {alert.get('biomarker', 'Unknown')}: {alert.get('message', '')}\n"
122
+
123
  prompt = ChatPromptTemplate.from_messages([
124
  ("system", """You are a clinical decision support system providing evidence-based recommendations.
125
  Based on clinical practice guidelines, provide actionable recommendations for patient self-assessment.
 
140
 
141
  Please provide structured recommendations for patient self-assessment.""")
142
  ])
143
+
144
  chain = prompt | self.llm
145
+
146
  try:
147
  response = chain.invoke({
148
  "disease": disease,
 
150
  "safety_context": safety_context,
151
  "guidelines": guidelines_context
152
  })
153
+
154
  recommendations = self._parse_recommendations(response.content)
155
+
156
  except Exception as e:
157
  print(f"Warning: LLM recommendation generation failed: {e}")
158
  recommendations = self._get_default_recommendations(disease, safety_alerts)
159
+
160
  # Add citations
161
  recommendations['citations'] = self._extract_citations(docs)
162
+
163
  return recommendations
164
+
165
  def _parse_recommendations(self, content: str) -> dict:
166
  """Parse LLM response into structured recommendations"""
167
  recommendations = {
 
169
  "lifestyle_changes": [],
170
  "monitoring": []
171
  }
172
+
173
  current_section = None
174
  lines = content.split('\n')
175
+
176
  for line in lines:
177
  line_stripped = line.strip()
178
  line_upper = line_stripped.upper()
179
+
180
  # Detect section headers
181
  if 'IMMEDIATE' in line_upper or 'URGENT' in line_upper:
182
  current_section = 'immediate_actions'
 
190
  cleaned = line_stripped.lstrip('•-*0123456789. ')
191
  if cleaned and len(cleaned) > 10: # Minimum length filter
192
  recommendations[current_section].append(cleaned)
193
+
194
  # If parsing failed, create default structure
195
  if not any(recommendations.values()):
196
  sentences = content.split('.')
197
  recommendations['immediate_actions'] = [s.strip() for s in sentences[:2] if s.strip()]
198
  recommendations['lifestyle_changes'] = [s.strip() for s in sentences[2:4] if s.strip()]
199
  recommendations['monitoring'] = [s.strip() for s in sentences[4:6] if s.strip()]
200
+
201
  return recommendations
202
+
203
  def _get_default_recommendations(self, disease: str, safety_alerts: list) -> dict:
204
  """Provide default recommendations if LLM fails"""
205
  recommendations = {
 
207
  "lifestyle_changes": [],
208
  "monitoring": []
209
  }
210
+
211
  # Add safety-based immediate actions
212
  if safety_alerts:
213
  recommendations['immediate_actions'].append(
 
220
  recommendations['immediate_actions'].append(
221
  f"Schedule appointment with healthcare provider to discuss {disease} findings"
222
  )
223
+
224
  # Generic lifestyle changes
225
  recommendations['lifestyle_changes'].extend([
226
  "Follow a balanced, nutrient-rich diet as recommended by healthcare provider",
227
  "Maintain regular physical activity appropriate for your health status",
228
  "Track symptoms and biomarker trends over time"
229
  ])
230
+
231
  # Generic monitoring
232
  recommendations['monitoring'].extend([
233
  f"Regular monitoring of {disease}-related biomarkers as advised by physician",
234
  "Keep a health journal tracking symptoms, diet, and activities",
235
  "Schedule follow-up appointments as recommended"
236
  ])
237
+
238
  return recommendations
239
+
240
+ def _extract_citations(self, docs: list) -> list[str]:
241
  """Extract citations from retrieved guideline documents"""
242
  citations = []
243
+
244
  for doc in docs:
245
  source = doc.metadata.get('source', 'Unknown')
246
+
247
  # Clean up source path
248
  if '\\' in source or '/' in source:
249
  source = Path(source).name
250
+
251
  citations.append(source)
252
+
253
  return list(set(citations)) # Remove duplicates
254
 
255
 
src/agents/confidence_assessor.py CHANGED
@@ -3,19 +3,19 @@ MediGuard AI RAG-Helper
3
  Confidence Assessor Agent - Evaluates prediction reliability
4
  """
5
 
6
- from typing import Any, Dict, List
7
- from src.state import GuildState, AgentOutput
8
  from src.biomarker_validator import BiomarkerValidator
9
  from src.llm_config import llm_config
10
- from langchain_core.prompts import ChatPromptTemplate
11
 
12
 
13
  class ConfidenceAssessorAgent:
14
  """Agent that assesses the reliability and limitations of the prediction"""
15
-
16
  def __init__(self):
17
  self.llm = llm_config.analyzer
18
-
19
  def assess(self, state: GuildState) -> GuildState:
20
  """
21
  Assess prediction confidence and identify limitations.
@@ -29,41 +29,41 @@ class ConfidenceAssessorAgent:
29
  print("\n" + "="*70)
30
  print("EXECUTING: Confidence Assessor Agent")
31
  print("="*70)
32
-
33
  model_prediction = state['model_prediction']
34
  disease = model_prediction['disease']
35
  ml_confidence = model_prediction['confidence']
36
  probabilities = model_prediction.get('probabilities', {})
37
  biomarkers = state['patient_biomarkers']
38
-
39
  # Collect previous agent findings
40
  biomarker_analysis = state.get('biomarker_analysis') or {}
41
  disease_explanation = self._get_agent_findings(state, "Disease Explainer")
42
  linker_findings = self._get_agent_findings(state, "Biomarker-Disease Linker")
43
-
44
  print(f"\nAssessing confidence for {disease} prediction...")
45
-
46
  # Evaluate evidence strength
47
  evidence_strength = self._evaluate_evidence_strength(
48
  biomarker_analysis,
49
  disease_explanation,
50
  linker_findings
51
  )
52
-
53
  # Identify limitations
54
  limitations = self._identify_limitations(
55
  biomarkers,
56
  biomarker_analysis,
57
  probabilities
58
  )
59
-
60
  # Calculate aggregate reliability
61
  reliability = self._calculate_reliability(
62
  ml_confidence,
63
  evidence_strength,
64
  len(limitations)
65
  )
66
-
67
  # Generate assessment summary
68
  assessment_summary = self._generate_assessment(
69
  disease,
@@ -72,7 +72,7 @@ class ConfidenceAssessorAgent:
72
  evidence_strength,
73
  limitations
74
  )
75
-
76
  # Create agent output
77
  output = AgentOutput(
78
  agent_name="Confidence Assessor",
@@ -86,22 +86,22 @@ class ConfidenceAssessorAgent:
86
  "alternative_diagnoses": self._get_alternatives(probabilities)
87
  }
88
  )
89
-
90
  # Update state
91
  print("\nConfidence assessment complete")
92
  print(f" - Prediction reliability: {reliability}")
93
  print(f" - Evidence strength: {evidence_strength}")
94
  print(f" - Limitations identified: {len(limitations)}")
95
-
96
  return {'agent_outputs': [output]}
97
-
98
  def _get_agent_findings(self, state: GuildState, agent_name: str) -> dict:
99
  """Extract findings from a specific agent"""
100
  for output in state.get('agent_outputs', []):
101
  if output.agent_name == agent_name:
102
  return output.findings
103
  return {}
104
-
105
  def _evaluate_evidence_strength(
106
  self,
107
  biomarker_analysis: dict,
@@ -109,10 +109,10 @@ class ConfidenceAssessorAgent:
109
  linker_findings: dict
110
  ) -> str:
111
  """Evaluate the strength of supporting evidence"""
112
-
113
  score = 0
114
  max_score = 5
115
-
116
  # Check biomarker validation quality
117
  flags = biomarker_analysis.get('biomarker_flags', [])
118
  abnormal_count = len([f for f in flags if f.get('status') != 'NORMAL'])
@@ -120,18 +120,18 @@ class ConfidenceAssessorAgent:
120
  score += 1
121
  if abnormal_count >= 5:
122
  score += 1
123
-
124
  # Check disease explanation quality
125
  if disease_explanation.get('retrieval_quality', 0) >= 3:
126
  score += 1
127
-
128
  # Check biomarker-disease linking
129
  key_drivers = linker_findings.get('key_drivers', [])
130
  if len(key_drivers) >= 2:
131
  score += 1
132
  if len(key_drivers) >= 4:
133
  score += 1
134
-
135
  # Map score to categorical rating
136
  if score >= 4:
137
  return "STRONG"
@@ -139,22 +139,22 @@ class ConfidenceAssessorAgent:
139
  return "MODERATE"
140
  else:
141
  return "WEAK"
142
-
143
  def _identify_limitations(
144
  self,
145
- biomarkers: Dict[str, float],
146
  biomarker_analysis: dict,
147
- probabilities: Dict[str, float]
148
- ) -> List[str]:
149
  """Identify limitations and uncertainties"""
150
  limitations = []
151
-
152
  # Check for missing biomarkers
153
  expected_biomarkers = BiomarkerValidator().expected_biomarker_count()
154
  if len(biomarkers) < expected_biomarkers:
155
  missing = expected_biomarkers - len(biomarkers)
156
  limitations.append(f"Missing data: {missing} biomarker(s) not provided")
157
-
158
  # Check for close alternative predictions
159
  sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
160
  if len(sorted_probs) >= 2:
@@ -164,7 +164,7 @@ class ConfidenceAssessorAgent:
164
  limitations.append(
165
  f"Differential diagnosis: {top2} also possible ({prob2:.1%} probability)"
166
  )
167
-
168
  # Check for normal biomarkers despite prediction
169
  flags = biomarker_analysis.get('biomarker_flags', [])
170
  relevant = biomarker_analysis.get('relevant_biomarkers', [])
@@ -174,18 +174,18 @@ class ConfidenceAssessorAgent:
174
  ]
175
  if len(normal_relevant) >= 2:
176
  limitations.append(
177
- f"Some disease-relevant biomarkers are within normal range"
178
  )
179
-
180
  # Check for safety alerts (indicates complexity)
181
  alerts = biomarker_analysis.get('safety_alerts', [])
182
  if len(alerts) >= 2:
183
  limitations.append(
184
  "Multiple critical values detected; professional evaluation essential"
185
  )
186
-
187
  return limitations
188
-
189
  def _calculate_reliability(
190
  self,
191
  ml_confidence: float,
@@ -193,9 +193,9 @@ class ConfidenceAssessorAgent:
193
  limitation_count: int
194
  ) -> str:
195
  """Calculate overall prediction reliability"""
196
-
197
  score = 0
198
-
199
  # ML confidence contribution
200
  if ml_confidence >= 0.8:
201
  score += 3
@@ -203,7 +203,7 @@ class ConfidenceAssessorAgent:
203
  score += 2
204
  elif ml_confidence >= 0.4:
205
  score += 1
206
-
207
  # Evidence strength contribution
208
  if evidence_strength == "STRONG":
209
  score += 3
@@ -211,10 +211,10 @@ class ConfidenceAssessorAgent:
211
  score += 2
212
  else:
213
  score += 1
214
-
215
  # Limitation penalty
216
  score -= min(limitation_count, 3)
217
-
218
  # Map to categorical
219
  if score >= 5:
220
  return "HIGH"
@@ -222,17 +222,17 @@ class ConfidenceAssessorAgent:
222
  return "MODERATE"
223
  else:
224
  return "LOW"
225
-
226
  def _generate_assessment(
227
  self,
228
  disease: str,
229
  ml_confidence: float,
230
  reliability: str,
231
  evidence_strength: str,
232
- limitations: List[str]
233
  ) -> str:
234
  """Generate human-readable assessment summary"""
235
-
236
  prompt = f"""As a medical AI assessment system, provide a brief confidence statement about this prediction:
237
 
238
  Disease Predicted: {disease}
@@ -254,7 +254,7 @@ Be honest about uncertainty. Patient safety is paramount."""
254
  except Exception as e:
255
  print(f"Warning: Assessment generation failed: {e}")
256
  return f"The {disease} prediction has {reliability.lower()} reliability based on available data. Professional medical evaluation is strongly recommended for accurate diagnosis."
257
-
258
  def _get_recommendation(self, reliability: str) -> str:
259
  """Get action recommendation based on reliability"""
260
  if reliability == "HIGH":
@@ -263,11 +263,11 @@ Be honest about uncertainty. Patient safety is paramount."""
263
  return "Moderate confidence prediction. Medical consultation recommended for professional evaluation and additional testing if needed."
264
  else:
265
  return "Low confidence prediction. Professional medical assessment essential. Additional tests may be required for accurate diagnosis."
266
-
267
- def _get_alternatives(self, probabilities: Dict[str, float]) -> List[Dict[str, Any]]:
268
  """Get alternative diagnoses to consider"""
269
  sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
270
-
271
  alternatives = []
272
  for disease, prob in sorted_probs[1:4]: # Top 3 alternatives
273
  if prob > 0.05: # Only significant alternatives
@@ -276,7 +276,7 @@ Be honest about uncertainty. Patient safety is paramount."""
276
  "probability": prob,
277
  "note": "Consider discussing with healthcare provider"
278
  })
279
-
280
  return alternatives
281
 
282
 
 
3
  Confidence Assessor Agent - Evaluates prediction reliability
4
  """
5
 
6
+ from typing import Any
7
+
8
  from src.biomarker_validator import BiomarkerValidator
9
  from src.llm_config import llm_config
10
+ from src.state import AgentOutput, GuildState
11
 
12
 
13
  class ConfidenceAssessorAgent:
14
  """Agent that assesses the reliability and limitations of the prediction"""
15
+
16
    def __init__(self):
        """Create the agent, binding the shared analyzer LLM from llm_config."""
        self.llm = llm_config.analyzer
18
+
19
  def assess(self, state: GuildState) -> GuildState:
20
  """
21
  Assess prediction confidence and identify limitations.
 
29
  print("\n" + "="*70)
30
  print("EXECUTING: Confidence Assessor Agent")
31
  print("="*70)
32
+
33
  model_prediction = state['model_prediction']
34
  disease = model_prediction['disease']
35
  ml_confidence = model_prediction['confidence']
36
  probabilities = model_prediction.get('probabilities', {})
37
  biomarkers = state['patient_biomarkers']
38
+
39
  # Collect previous agent findings
40
  biomarker_analysis = state.get('biomarker_analysis') or {}
41
  disease_explanation = self._get_agent_findings(state, "Disease Explainer")
42
  linker_findings = self._get_agent_findings(state, "Biomarker-Disease Linker")
43
+
44
  print(f"\nAssessing confidence for {disease} prediction...")
45
+
46
  # Evaluate evidence strength
47
  evidence_strength = self._evaluate_evidence_strength(
48
  biomarker_analysis,
49
  disease_explanation,
50
  linker_findings
51
  )
52
+
53
  # Identify limitations
54
  limitations = self._identify_limitations(
55
  biomarkers,
56
  biomarker_analysis,
57
  probabilities
58
  )
59
+
60
  # Calculate aggregate reliability
61
  reliability = self._calculate_reliability(
62
  ml_confidence,
63
  evidence_strength,
64
  len(limitations)
65
  )
66
+
67
  # Generate assessment summary
68
  assessment_summary = self._generate_assessment(
69
  disease,
 
72
  evidence_strength,
73
  limitations
74
  )
75
+
76
  # Create agent output
77
  output = AgentOutput(
78
  agent_name="Confidence Assessor",
 
86
  "alternative_diagnoses": self._get_alternatives(probabilities)
87
  }
88
  )
89
+
90
  # Update state
91
  print("\nConfidence assessment complete")
92
  print(f" - Prediction reliability: {reliability}")
93
  print(f" - Evidence strength: {evidence_strength}")
94
  print(f" - Limitations identified: {len(limitations)}")
95
+
96
  return {'agent_outputs': [output]}
97
+
98
  def _get_agent_findings(self, state: GuildState, agent_name: str) -> dict:
99
  """Extract findings from a specific agent"""
100
  for output in state.get('agent_outputs', []):
101
  if output.agent_name == agent_name:
102
  return output.findings
103
  return {}
104
+
105
  def _evaluate_evidence_strength(
106
  self,
107
  biomarker_analysis: dict,
 
109
  linker_findings: dict
110
  ) -> str:
111
  """Evaluate the strength of supporting evidence"""
112
+
113
  score = 0
114
  max_score = 5
115
+
116
  # Check biomarker validation quality
117
  flags = biomarker_analysis.get('biomarker_flags', [])
118
  abnormal_count = len([f for f in flags if f.get('status') != 'NORMAL'])
 
120
  score += 1
121
  if abnormal_count >= 5:
122
  score += 1
123
+
124
  # Check disease explanation quality
125
  if disease_explanation.get('retrieval_quality', 0) >= 3:
126
  score += 1
127
+
128
  # Check biomarker-disease linking
129
  key_drivers = linker_findings.get('key_drivers', [])
130
  if len(key_drivers) >= 2:
131
  score += 1
132
  if len(key_drivers) >= 4:
133
  score += 1
134
+
135
  # Map score to categorical rating
136
  if score >= 4:
137
  return "STRONG"
 
139
  return "MODERATE"
140
  else:
141
  return "WEAK"
142
+
143
  def _identify_limitations(
144
  self,
145
+ biomarkers: dict[str, float],
146
  biomarker_analysis: dict,
147
+ probabilities: dict[str, float]
148
+ ) -> list[str]:
149
  """Identify limitations and uncertainties"""
150
  limitations = []
151
+
152
  # Check for missing biomarkers
153
  expected_biomarkers = BiomarkerValidator().expected_biomarker_count()
154
  if len(biomarkers) < expected_biomarkers:
155
  missing = expected_biomarkers - len(biomarkers)
156
  limitations.append(f"Missing data: {missing} biomarker(s) not provided")
157
+
158
  # Check for close alternative predictions
159
  sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
160
  if len(sorted_probs) >= 2:
 
164
  limitations.append(
165
  f"Differential diagnosis: {top2} also possible ({prob2:.1%} probability)"
166
  )
167
+
168
  # Check for normal biomarkers despite prediction
169
  flags = biomarker_analysis.get('biomarker_flags', [])
170
  relevant = biomarker_analysis.get('relevant_biomarkers', [])
 
174
  ]
175
  if len(normal_relevant) >= 2:
176
  limitations.append(
177
+ "Some disease-relevant biomarkers are within normal range"
178
  )
179
+
180
  # Check for safety alerts (indicates complexity)
181
  alerts = biomarker_analysis.get('safety_alerts', [])
182
  if len(alerts) >= 2:
183
  limitations.append(
184
  "Multiple critical values detected; professional evaluation essential"
185
  )
186
+
187
  return limitations
188
+
189
  def _calculate_reliability(
190
  self,
191
  ml_confidence: float,
 
193
  limitation_count: int
194
  ) -> str:
195
  """Calculate overall prediction reliability"""
196
+
197
  score = 0
198
+
199
  # ML confidence contribution
200
  if ml_confidence >= 0.8:
201
  score += 3
 
203
  score += 2
204
  elif ml_confidence >= 0.4:
205
  score += 1
206
+
207
  # Evidence strength contribution
208
  if evidence_strength == "STRONG":
209
  score += 3
 
211
  score += 2
212
  else:
213
  score += 1
214
+
215
  # Limitation penalty
216
  score -= min(limitation_count, 3)
217
+
218
  # Map to categorical
219
  if score >= 5:
220
  return "HIGH"
 
222
  return "MODERATE"
223
  else:
224
  return "LOW"
225
+
226
  def _generate_assessment(
227
  self,
228
  disease: str,
229
  ml_confidence: float,
230
  reliability: str,
231
  evidence_strength: str,
232
+ limitations: list[str]
233
  ) -> str:
234
  """Generate human-readable assessment summary"""
235
+
236
  prompt = f"""As a medical AI assessment system, provide a brief confidence statement about this prediction:
237
 
238
  Disease Predicted: {disease}
 
254
  except Exception as e:
255
  print(f"Warning: Assessment generation failed: {e}")
256
  return f"The {disease} prediction has {reliability.lower()} reliability based on available data. Professional medical evaluation is strongly recommended for accurate diagnosis."
257
+
258
  def _get_recommendation(self, reliability: str) -> str:
259
  """Get action recommendation based on reliability"""
260
  if reliability == "HIGH":
 
263
  return "Moderate confidence prediction. Medical consultation recommended for professional evaluation and additional testing if needed."
264
  else:
265
  return "Low confidence prediction. Professional medical assessment essential. Additional tests may be required for accurate diagnosis."
266
+
267
+ def _get_alternatives(self, probabilities: dict[str, float]) -> list[dict[str, Any]]:
268
  """Get alternative diagnoses to consider"""
269
  sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
270
+
271
  alternatives = []
272
  for disease, prob in sorted_probs[1:4]: # Top 3 alternatives
273
  if prob > 0.05: # Only significant alternatives
 
276
  "probability": prob,
277
  "note": "Consider discussing with healthcare provider"
278
  })
279
+
280
  return alternatives
281
 
282
 
src/agents/disease_explainer.py CHANGED
@@ -4,14 +4,16 @@ Disease Explainer Agent - Retrieves disease pathophysiology from medical PDFs
4
  """
5
 
6
  from pathlib import Path
7
- from src.state import GuildState, AgentOutput
8
- from src.llm_config import llm_config
9
  from langchain_core.prompts import ChatPromptTemplate
10
 
 
 
 
11
 
12
  class DiseaseExplainerAgent:
13
  """Agent that retrieves and explains disease mechanisms using RAG"""
14
-
15
  def __init__(self, retriever):
16
  """
17
  Initialize with a retriever for medical PDFs.
@@ -21,7 +23,7 @@ class DiseaseExplainerAgent:
21
  """
22
  self.retriever = retriever
23
  self.llm = llm_config.explainer
24
-
25
  def explain(self, state: GuildState) -> GuildState:
26
  """
27
  Retrieve and explain disease pathophysiology.
@@ -35,23 +37,23 @@ class DiseaseExplainerAgent:
35
  print("\n" + "="*70)
36
  print("EXECUTING: Disease Explainer Agent (RAG)")
37
  print("="*70)
38
-
39
  model_prediction = state['model_prediction']
40
  disease = model_prediction['disease']
41
  confidence = model_prediction['confidence']
42
-
43
  # Configure retrieval based on SOP — create a copy to avoid mutating shared retriever
44
  retrieval_k = state['sop'].disease_explainer_k
45
  original_search_kwargs = dict(self.retriever.search_kwargs)
46
  self.retriever.search_kwargs = {**original_search_kwargs, 'k': retrieval_k}
47
-
48
  # Retrieve relevant documents
49
  print(f"\nRetrieving information about: {disease}")
50
  print(f"Retrieval k={state['sop'].disease_explainer_k}")
51
-
52
  query = f"""What is {disease}? Explain the pathophysiology, diagnostic criteria,
53
  and clinical presentation. Focus on mechanisms relevant to blood biomarkers."""
54
-
55
  try:
56
  docs = self.retriever.invoke(query)
57
  finally:
@@ -87,13 +89,13 @@ class DiseaseExplainerAgent:
87
  print(" - Pathophysiology: insufficient evidence")
88
  print(" - Citations: 0 sources")
89
  return {'agent_outputs': [output]}
90
-
91
  # Generate explanation
92
  explanation = self._generate_explanation(disease, docs, confidence)
93
-
94
  # Extract citations
95
  citations = self._extract_citations(docs)
96
-
97
  # Create agent output
98
  output = AgentOutput(
99
  agent_name="Disease Explainer",
@@ -109,23 +111,23 @@ class DiseaseExplainerAgent:
109
  "citations_missing": False
110
  }
111
  )
112
-
113
  # Update state
114
  print("\nDisease explanation generated")
115
  print(f" - Pathophysiology: {len(explanation['pathophysiology'])} chars")
116
  print(f" - Citations: {len(citations)} sources")
117
-
118
  return {'agent_outputs': [output]}
119
-
120
  def _generate_explanation(self, disease: str, docs: list, confidence: float) -> dict:
121
  """Generate structured disease explanation using LLM and retrieved docs"""
122
-
123
  # Format retrieved context
124
  context = "\n\n---\n\n".join([
125
  f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
126
  for doc in docs
127
  ])
128
-
129
  prompt = ChatPromptTemplate.from_messages([
130
  ("system", """You are a medical expert explaining diseases for patient self-assessment.
131
  Based on the provided medical literature, explain the disease in clear, accessible language.
@@ -144,20 +146,20 @@ class DiseaseExplainerAgent:
144
 
145
  Please provide a structured explanation.""")
146
  ])
147
-
148
  chain = prompt | self.llm
149
-
150
  try:
151
  response = chain.invoke({
152
  "disease": disease,
153
  "confidence": confidence,
154
  "context": context
155
  })
156
-
157
  # Parse structured response
158
  content = response.content
159
  explanation = self._parse_explanation(content)
160
-
161
  except Exception as e:
162
  print(f"Warning: LLM explanation generation failed: {e}")
163
  explanation = {
@@ -166,9 +168,9 @@ class DiseaseExplainerAgent:
166
  "clinical_presentation": "Clinical presentation varies by individual.",
167
  "summary": f"{disease} detected with {confidence:.1%} confidence. Consult healthcare provider."
168
  }
169
-
170
  return explanation
171
-
172
  def _parse_explanation(self, content: str) -> dict:
173
  """Parse LLM response into structured sections"""
174
  sections = {
@@ -177,14 +179,14 @@ class DiseaseExplainerAgent:
177
  "clinical_presentation": "",
178
  "summary": ""
179
  }
180
-
181
  # Simple parsing logic
182
  current_section = None
183
  lines = content.split('\n')
184
-
185
  for line in lines:
186
  line_upper = line.upper().strip()
187
-
188
  if 'PATHOPHYSIOLOGY' in line_upper:
189
  current_section = 'pathophysiology'
190
  elif 'DIAGNOSTIC' in line_upper:
@@ -195,31 +197,31 @@ class DiseaseExplainerAgent:
195
  current_section = 'summary'
196
  elif current_section and line.strip():
197
  sections[current_section] += line + "\n"
198
-
199
  # If parsing failed, use full content as summary
200
  if not any(sections.values()):
201
  sections['summary'] = content[:500]
202
-
203
  return sections
204
-
205
  def _extract_citations(self, docs: list) -> list:
206
  """Extract citations from retrieved documents"""
207
  citations = []
208
-
209
  for doc in docs:
210
  source = doc.metadata.get('source', 'Unknown')
211
  page = doc.metadata.get('page', 'N/A')
212
-
213
  # Clean up source path
214
  if '\\' in source or '/' in source:
215
  source = Path(source).name
216
-
217
  citation = f"{source}"
218
  if page != 'N/A':
219
  citation += f" (Page {page})"
220
-
221
  citations.append(citation)
222
-
223
  return citations
224
 
225
 
 
4
  """
5
 
6
  from pathlib import Path
7
+
 
8
  from langchain_core.prompts import ChatPromptTemplate
9
 
10
+ from src.llm_config import llm_config
11
+ from src.state import AgentOutput, GuildState
12
+
13
 
14
  class DiseaseExplainerAgent:
15
  """Agent that retrieves and explains disease mechanisms using RAG"""
16
+
17
  def __init__(self, retriever):
18
  """
19
  Initialize with a retriever for medical PDFs.
 
23
  """
24
  self.retriever = retriever
25
  self.llm = llm_config.explainer
26
+
27
  def explain(self, state: GuildState) -> GuildState:
28
  """
29
  Retrieve and explain disease pathophysiology.
 
37
  print("\n" + "="*70)
38
  print("EXECUTING: Disease Explainer Agent (RAG)")
39
  print("="*70)
40
+
41
  model_prediction = state['model_prediction']
42
  disease = model_prediction['disease']
43
  confidence = model_prediction['confidence']
44
+
45
  # Configure retrieval based on SOP — create a copy to avoid mutating shared retriever
46
  retrieval_k = state['sop'].disease_explainer_k
47
  original_search_kwargs = dict(self.retriever.search_kwargs)
48
  self.retriever.search_kwargs = {**original_search_kwargs, 'k': retrieval_k}
49
+
50
  # Retrieve relevant documents
51
  print(f"\nRetrieving information about: {disease}")
52
  print(f"Retrieval k={state['sop'].disease_explainer_k}")
53
+
54
  query = f"""What is {disease}? Explain the pathophysiology, diagnostic criteria,
55
  and clinical presentation. Focus on mechanisms relevant to blood biomarkers."""
56
+
57
  try:
58
  docs = self.retriever.invoke(query)
59
  finally:
 
89
  print(" - Pathophysiology: insufficient evidence")
90
  print(" - Citations: 0 sources")
91
  return {'agent_outputs': [output]}
92
+
93
  # Generate explanation
94
  explanation = self._generate_explanation(disease, docs, confidence)
95
+
96
  # Extract citations
97
  citations = self._extract_citations(docs)
98
+
99
  # Create agent output
100
  output = AgentOutput(
101
  agent_name="Disease Explainer",
 
111
  "citations_missing": False
112
  }
113
  )
114
+
115
  # Update state
116
  print("\nDisease explanation generated")
117
  print(f" - Pathophysiology: {len(explanation['pathophysiology'])} chars")
118
  print(f" - Citations: {len(citations)} sources")
119
+
120
  return {'agent_outputs': [output]}
121
+
122
  def _generate_explanation(self, disease: str, docs: list, confidence: float) -> dict:
123
  """Generate structured disease explanation using LLM and retrieved docs"""
124
+
125
  # Format retrieved context
126
  context = "\n\n---\n\n".join([
127
  f"Source: {doc.metadata.get('source', 'Unknown')}\n\n{doc.page_content}"
128
  for doc in docs
129
  ])
130
+
131
  prompt = ChatPromptTemplate.from_messages([
132
  ("system", """You are a medical expert explaining diseases for patient self-assessment.
133
  Based on the provided medical literature, explain the disease in clear, accessible language.
 
146
 
147
  Please provide a structured explanation.""")
148
  ])
149
+
150
  chain = prompt | self.llm
151
+
152
  try:
153
  response = chain.invoke({
154
  "disease": disease,
155
  "confidence": confidence,
156
  "context": context
157
  })
158
+
159
  # Parse structured response
160
  content = response.content
161
  explanation = self._parse_explanation(content)
162
+
163
  except Exception as e:
164
  print(f"Warning: LLM explanation generation failed: {e}")
165
  explanation = {
 
168
  "clinical_presentation": "Clinical presentation varies by individual.",
169
  "summary": f"{disease} detected with {confidence:.1%} confidence. Consult healthcare provider."
170
  }
171
+
172
  return explanation
173
+
174
  def _parse_explanation(self, content: str) -> dict:
175
  """Parse LLM response into structured sections"""
176
  sections = {
 
179
  "clinical_presentation": "",
180
  "summary": ""
181
  }
182
+
183
  # Simple parsing logic
184
  current_section = None
185
  lines = content.split('\n')
186
+
187
  for line in lines:
188
  line_upper = line.upper().strip()
189
+
190
  if 'PATHOPHYSIOLOGY' in line_upper:
191
  current_section = 'pathophysiology'
192
  elif 'DIAGNOSTIC' in line_upper:
 
197
  current_section = 'summary'
198
  elif current_section and line.strip():
199
  sections[current_section] += line + "\n"
200
+
201
  # If parsing failed, use full content as summary
202
  if not any(sections.values()):
203
  sections['summary'] = content[:500]
204
+
205
  return sections
206
+
207
  def _extract_citations(self, docs: list) -> list:
208
  """Extract citations from retrieved documents"""
209
  citations = []
210
+
211
  for doc in docs:
212
  source = doc.metadata.get('source', 'Unknown')
213
  page = doc.metadata.get('page', 'N/A')
214
+
215
  # Clean up source path
216
  if '\\' in source or '/' in source:
217
  source = Path(source).name
218
+
219
  citation = f"{source}"
220
  if page != 'N/A':
221
  citation += f" (Page {page})"
222
+
223
  citations.append(citation)
224
+
225
  return citations
226
 
227
 
src/agents/response_synthesizer.py CHANGED
@@ -3,19 +3,20 @@ MediGuard AI RAG-Helper
3
  Response Synthesizer Agent - Compiles all findings into final structured JSON
4
  """
5
 
6
- import json
7
- from typing import Dict, List, Any
8
- from src.state import GuildState
9
- from src.llm_config import llm_config
10
  from langchain_core.prompts import ChatPromptTemplate
11
 
 
 
 
12
 
13
  class ResponseSynthesizerAgent:
14
  """Agent that synthesizes all specialist findings into the final response"""
15
-
16
  def __init__(self):
17
  self.llm = llm_config.get_synthesizer()
18
-
19
  def synthesize(self, state: GuildState) -> GuildState:
20
  """
21
  Synthesize all agent outputs into final response.
@@ -29,17 +30,17 @@ class ResponseSynthesizerAgent:
29
  print("\n" + "="*70)
30
  print("EXECUTING: Response Synthesizer Agent")
31
  print("="*70)
32
-
33
  model_prediction = state['model_prediction']
34
  patient_biomarkers = state['patient_biomarkers']
35
  patient_context = state.get('patient_context', {})
36
  agent_outputs = state.get('agent_outputs', [])
37
-
38
  # Collect findings from all agents
39
  findings = self._collect_findings(agent_outputs)
40
-
41
  print(f"\nSynthesizing findings from {len(agent_outputs)} specialist agents...")
42
-
43
  # Build structured response
44
  recs = self._build_recommendations(findings)
45
  response = {
@@ -64,38 +65,38 @@ class ResponseSynthesizerAgent:
64
  "alternative_diagnoses": self._build_alternative_diagnoses(findings)
65
  }
66
  }
67
-
68
  # Generate patient-friendly summary
69
  response["patient_summary"]["narrative"] = self._generate_narrative_summary(
70
  model_prediction,
71
  findings,
72
  response
73
  )
74
-
75
  print("\nResponse synthesis complete")
76
- print(f" - Patient summary: Generated")
77
  print(f" - Prediction explanation: {len(response['prediction_explanation']['key_drivers'])} key drivers")
78
  print(f" - Recommendations: {len(response['clinical_recommendations']['immediate_actions'])} immediate actions")
79
  print(f" - Safety alerts: {len(response['safety_alerts'])} alerts")
80
-
81
  return {'final_response': response}
82
-
83
- def _collect_findings(self, agent_outputs: List) -> Dict[str, Any]:
84
  """Organize all agent findings by agent name"""
85
  findings = {}
86
  for output in agent_outputs:
87
  findings[output.agent_name] = output.findings
88
  return findings
89
-
90
- def _build_patient_summary(self, biomarkers: Dict, findings: Dict) -> Dict:
91
  """Build patient summary section"""
92
  biomarker_analysis = findings.get("Biomarker Analyzer", {})
93
  flags = biomarker_analysis.get('biomarker_flags', [])
94
-
95
  # Count biomarker statuses
96
  critical = len([f for f in flags if 'CRITICAL' in f.get('status', '')])
97
  abnormal = len([f for f in flags if f.get('status') != 'NORMAL'])
98
-
99
  return {
100
  "total_biomarkers_tested": len(biomarkers),
101
  "biomarkers_in_normal_range": len(flags) - abnormal,
@@ -104,15 +105,15 @@ class ResponseSynthesizerAgent:
104
  "overall_risk_profile": biomarker_analysis.get('summary', 'Assessment complete'),
105
  "narrative": "" # Will be filled later
106
  }
107
-
108
- def _build_prediction_explanation(self, model_prediction: Dict, findings: Dict) -> Dict:
109
  """Build prediction explanation section"""
110
  disease_explanation = findings.get("Disease Explainer", {})
111
  linker_findings = findings.get("Biomarker-Disease Linker", {})
112
-
113
  disease = model_prediction['disease']
114
  confidence = model_prediction['confidence']
115
-
116
  # Get key drivers
117
  key_drivers_raw = linker_findings.get('key_drivers', [])
118
  key_drivers = [
@@ -125,7 +126,7 @@ class ResponseSynthesizerAgent:
125
  }
126
  for kd in key_drivers_raw
127
  ]
128
-
129
  return {
130
  "primary_disease": disease,
131
  "confidence": confidence,
@@ -135,37 +136,37 @@ class ResponseSynthesizerAgent:
135
  "pdf_references": disease_explanation.get('citations', [])
136
  }
137
 
138
- def _build_biomarker_flags(self, findings: Dict) -> List[Dict]:
139
  biomarker_analysis = findings.get("Biomarker Analyzer", {})
140
  return biomarker_analysis.get('biomarker_flags', [])
141
 
142
- def _build_key_drivers(self, findings: Dict) -> List[Dict]:
143
  linker_findings = findings.get("Biomarker-Disease Linker", {})
144
  return linker_findings.get('key_drivers', [])
145
 
146
- def _build_disease_explanation(self, findings: Dict) -> Dict:
147
  disease_explanation = findings.get("Disease Explainer", {})
148
  return {
149
  "pathophysiology": disease_explanation.get('pathophysiology', ''),
150
  "citations": disease_explanation.get('citations', []),
151
  "retrieved_chunks": disease_explanation.get('retrieved_chunks')
152
  }
153
-
154
- def _build_recommendations(self, findings: Dict) -> Dict:
155
  """Build clinical recommendations section"""
156
  guidelines = findings.get("Clinical Guidelines", {})
157
-
158
  return {
159
  "immediate_actions": guidelines.get('immediate_actions', []),
160
  "lifestyle_changes": guidelines.get('lifestyle_changes', []),
161
  "monitoring": guidelines.get('monitoring', []),
162
  "guideline_citations": guidelines.get('guideline_citations', [])
163
  }
164
-
165
- def _build_confidence_assessment(self, findings: Dict) -> Dict:
166
  """Build confidence assessment section"""
167
  assessment = findings.get("Confidence Assessor", {})
168
-
169
  return {
170
  "prediction_reliability": assessment.get('prediction_reliability', 'UNKNOWN'),
171
  "evidence_strength": assessment.get('evidence_strength', 'UNKNOWN'),
@@ -175,19 +176,19 @@ class ResponseSynthesizerAgent:
175
  "alternative_diagnoses": assessment.get('alternative_diagnoses', [])
176
  }
177
 
178
- def _build_alternative_diagnoses(self, findings: Dict) -> List[Dict]:
179
  assessment = findings.get("Confidence Assessor", {})
180
  return assessment.get('alternative_diagnoses', [])
181
-
182
- def _build_safety_alerts(self, findings: Dict) -> List[Dict]:
183
  """Build safety alerts section"""
184
  biomarker_analysis = findings.get("Biomarker Analyzer", {})
185
  return biomarker_analysis.get('safety_alerts', [])
186
-
187
- def _build_metadata(self, state: GuildState) -> Dict:
188
  """Build metadata section"""
189
  from datetime import datetime
190
-
191
  return {
192
  "timestamp": datetime.now().isoformat(),
193
  "system_version": "MediGuard AI RAG-Helper v1.0",
@@ -195,24 +196,24 @@ class ResponseSynthesizerAgent:
195
  "agents_executed": [output.agent_name for output in state.get('agent_outputs', [])],
196
  "disclaimer": "This is an AI-assisted analysis tool for patient self-assessment. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always consult qualified healthcare providers for medical decisions."
197
  }
198
-
199
  def _generate_narrative_summary(
200
  self,
201
  model_prediction,
202
- findings: Dict,
203
- response: Dict
204
  ) -> str:
205
  """Generate a patient-friendly narrative summary using LLM"""
206
-
207
  disease = model_prediction['disease']
208
  confidence = model_prediction['confidence']
209
  reliability = response['confidence_assessment']['prediction_reliability']
210
-
211
  # Get key points
212
  critical_count = response['patient_summary']['critical_values']
213
  abnormal_count = response['patient_summary']['biomarkers_out_of_range']
214
  key_drivers = response['prediction_explanation']['key_drivers']
215
-
216
  prompt = ChatPromptTemplate.from_messages([
217
  ("system", """You are a medical AI assistant explaining test results to a patient.
218
  Write a clear, compassionate 3-4 sentence summary that:
@@ -231,12 +232,12 @@ class ResponseSynthesizerAgent:
231
 
232
  Write a compassionate patient summary.""")
233
  ])
234
-
235
  chain = prompt | self.llm
236
-
237
  try:
238
  driver_names = [kd['biomarker'] for kd in key_drivers[:3]]
239
-
240
  response_obj = chain.invoke({
241
  "disease": disease,
242
  "confidence": confidence,
@@ -245,9 +246,9 @@ class ResponseSynthesizerAgent:
245
  "abnormal": abnormal_count,
246
  "drivers": ", ".join(driver_names) if driver_names else "Multiple biomarkers"
247
  })
248
-
249
  return response_obj.content.strip()
250
-
251
  except Exception as e:
252
  print(f"Warning: Narrative generation failed: {e}")
253
  return f"Your test results suggest {disease} with {confidence:.1%} confidence. {abnormal_count} biomarker(s) are out of normal range. Please consult with a healthcare provider for professional evaluation and guidance."
 
3
  Response Synthesizer Agent - Compiles all findings into final structured JSON
4
  """
5
 
6
+ from typing import Any
7
+
 
 
8
  from langchain_core.prompts import ChatPromptTemplate
9
 
10
+ from src.llm_config import llm_config
11
+ from src.state import GuildState
12
+
13
 
14
  class ResponseSynthesizerAgent:
15
  """Agent that synthesizes all specialist findings into the final response"""
16
+
17
  def __init__(self):
18
  self.llm = llm_config.get_synthesizer()
19
+
20
  def synthesize(self, state: GuildState) -> GuildState:
21
  """
22
  Synthesize all agent outputs into final response.
 
30
  print("\n" + "="*70)
31
  print("EXECUTING: Response Synthesizer Agent")
32
  print("="*70)
33
+
34
  model_prediction = state['model_prediction']
35
  patient_biomarkers = state['patient_biomarkers']
36
  patient_context = state.get('patient_context', {})
37
  agent_outputs = state.get('agent_outputs', [])
38
+
39
  # Collect findings from all agents
40
  findings = self._collect_findings(agent_outputs)
41
+
42
  print(f"\nSynthesizing findings from {len(agent_outputs)} specialist agents...")
43
+
44
  # Build structured response
45
  recs = self._build_recommendations(findings)
46
  response = {
 
65
  "alternative_diagnoses": self._build_alternative_diagnoses(findings)
66
  }
67
  }
68
+
69
  # Generate patient-friendly summary
70
  response["patient_summary"]["narrative"] = self._generate_narrative_summary(
71
  model_prediction,
72
  findings,
73
  response
74
  )
75
+
76
  print("\nResponse synthesis complete")
77
+ print(" - Patient summary: Generated")
78
  print(f" - Prediction explanation: {len(response['prediction_explanation']['key_drivers'])} key drivers")
79
  print(f" - Recommendations: {len(response['clinical_recommendations']['immediate_actions'])} immediate actions")
80
  print(f" - Safety alerts: {len(response['safety_alerts'])} alerts")
81
+
82
  return {'final_response': response}
83
+
84
+ def _collect_findings(self, agent_outputs: list) -> dict[str, Any]:
85
  """Organize all agent findings by agent name"""
86
  findings = {}
87
  for output in agent_outputs:
88
  findings[output.agent_name] = output.findings
89
  return findings
90
+
91
+ def _build_patient_summary(self, biomarkers: dict, findings: dict) -> dict:
92
  """Build patient summary section"""
93
  biomarker_analysis = findings.get("Biomarker Analyzer", {})
94
  flags = biomarker_analysis.get('biomarker_flags', [])
95
+
96
  # Count biomarker statuses
97
  critical = len([f for f in flags if 'CRITICAL' in f.get('status', '')])
98
  abnormal = len([f for f in flags if f.get('status') != 'NORMAL'])
99
+
100
  return {
101
  "total_biomarkers_tested": len(biomarkers),
102
  "biomarkers_in_normal_range": len(flags) - abnormal,
 
105
  "overall_risk_profile": biomarker_analysis.get('summary', 'Assessment complete'),
106
  "narrative": "" # Will be filled later
107
  }
108
+
109
+ def _build_prediction_explanation(self, model_prediction: dict, findings: dict) -> dict:
110
  """Build prediction explanation section"""
111
  disease_explanation = findings.get("Disease Explainer", {})
112
  linker_findings = findings.get("Biomarker-Disease Linker", {})
113
+
114
  disease = model_prediction['disease']
115
  confidence = model_prediction['confidence']
116
+
117
  # Get key drivers
118
  key_drivers_raw = linker_findings.get('key_drivers', [])
119
  key_drivers = [
 
126
  }
127
  for kd in key_drivers_raw
128
  ]
129
+
130
  return {
131
  "primary_disease": disease,
132
  "confidence": confidence,
 
136
  "pdf_references": disease_explanation.get('citations', [])
137
  }
138
 
139
+ def _build_biomarker_flags(self, findings: dict) -> list[dict]:
140
  biomarker_analysis = findings.get("Biomarker Analyzer", {})
141
  return biomarker_analysis.get('biomarker_flags', [])
142
 
143
+ def _build_key_drivers(self, findings: dict) -> list[dict]:
144
  linker_findings = findings.get("Biomarker-Disease Linker", {})
145
  return linker_findings.get('key_drivers', [])
146
 
147
+ def _build_disease_explanation(self, findings: dict) -> dict:
148
  disease_explanation = findings.get("Disease Explainer", {})
149
  return {
150
  "pathophysiology": disease_explanation.get('pathophysiology', ''),
151
  "citations": disease_explanation.get('citations', []),
152
  "retrieved_chunks": disease_explanation.get('retrieved_chunks')
153
  }
154
+
155
+ def _build_recommendations(self, findings: dict) -> dict:
156
  """Build clinical recommendations section"""
157
  guidelines = findings.get("Clinical Guidelines", {})
158
+
159
  return {
160
  "immediate_actions": guidelines.get('immediate_actions', []),
161
  "lifestyle_changes": guidelines.get('lifestyle_changes', []),
162
  "monitoring": guidelines.get('monitoring', []),
163
  "guideline_citations": guidelines.get('guideline_citations', [])
164
  }
165
+
166
+ def _build_confidence_assessment(self, findings: dict) -> dict:
167
  """Build confidence assessment section"""
168
  assessment = findings.get("Confidence Assessor", {})
169
+
170
  return {
171
  "prediction_reliability": assessment.get('prediction_reliability', 'UNKNOWN'),
172
  "evidence_strength": assessment.get('evidence_strength', 'UNKNOWN'),
 
176
  "alternative_diagnoses": assessment.get('alternative_diagnoses', [])
177
  }
178
 
179
+ def _build_alternative_diagnoses(self, findings: dict) -> list[dict]:
180
  assessment = findings.get("Confidence Assessor", {})
181
  return assessment.get('alternative_diagnoses', [])
182
+
183
+ def _build_safety_alerts(self, findings: dict) -> list[dict]:
184
  """Build safety alerts section"""
185
  biomarker_analysis = findings.get("Biomarker Analyzer", {})
186
  return biomarker_analysis.get('safety_alerts', [])
187
+
188
+ def _build_metadata(self, state: GuildState) -> dict:
189
  """Build metadata section"""
190
  from datetime import datetime
191
+
192
  return {
193
  "timestamp": datetime.now().isoformat(),
194
  "system_version": "MediGuard AI RAG-Helper v1.0",
 
196
  "agents_executed": [output.agent_name for output in state.get('agent_outputs', [])],
197
  "disclaimer": "This is an AI-assisted analysis tool for patient self-assessment. It is NOT a substitute for professional medical advice, diagnosis, or treatment. Always consult qualified healthcare providers for medical decisions."
198
  }
199
+
200
  def _generate_narrative_summary(
201
  self,
202
  model_prediction,
203
+ findings: dict,
204
+ response: dict
205
  ) -> str:
206
  """Generate a patient-friendly narrative summary using LLM"""
207
+
208
  disease = model_prediction['disease']
209
  confidence = model_prediction['confidence']
210
  reliability = response['confidence_assessment']['prediction_reliability']
211
+
212
  # Get key points
213
  critical_count = response['patient_summary']['critical_values']
214
  abnormal_count = response['patient_summary']['biomarkers_out_of_range']
215
  key_drivers = response['prediction_explanation']['key_drivers']
216
+
217
  prompt = ChatPromptTemplate.from_messages([
218
  ("system", """You are a medical AI assistant explaining test results to a patient.
219
  Write a clear, compassionate 3-4 sentence summary that:
 
232
 
233
  Write a compassionate patient summary.""")
234
  ])
235
+
236
  chain = prompt | self.llm
237
+
238
  try:
239
  driver_names = [kd['biomarker'] for kd in key_drivers[:3]]
240
+
241
  response_obj = chain.invoke({
242
  "disease": disease,
243
  "confidence": confidence,
 
246
  "abnormal": abnormal_count,
247
  "drivers": ", ".join(driver_names) if driver_names else "Multiple biomarkers"
248
  })
249
+
250
  return response_obj.content.strip()
251
+
252
  except Exception as e:
253
  print(f"Warning: Narrative generation failed: {e}")
254
  return f"Your test results suggest {disease} with {confidence:.1%} confidence. {abnormal_count} biomarker(s) are out of normal range. Please consult with a healthcare provider for professional evaluation and guidance."
src/biomarker_normalization.py CHANGED
@@ -3,10 +3,9 @@ MediGuard AI RAG-Helper
3
  Shared biomarker normalization utilities
4
  """
5
 
6
- from typing import Dict
7
 
8
  # Normalization map for biomarker aliases to canonical names.
9
- NORMALIZATION_MAP: Dict[str, str] = {
10
  # Glucose variations
11
  "glucose": "Glucose",
12
  "bloodsugar": "Glucose",
 
3
  Shared biomarker normalization utilities
4
  """
5
 
 
6
 
7
  # Normalization map for biomarker aliases to canonical names.
8
+ NORMALIZATION_MAP: dict[str, str] = {
9
  # Glucose variations
10
  "glucose": "Glucose",
11
  "bloodsugar": "Glucose",
src/biomarker_validator.py CHANGED
@@ -5,24 +5,24 @@ Biomarker analysis and validation utilities
5
 
6
  import json
7
  from pathlib import Path
8
- from typing import Dict, List, Tuple, Optional
9
  from src.state import BiomarkerFlag, SafetyAlert
10
 
11
 
12
  class BiomarkerValidator:
13
  """Validates biomarker values against reference ranges"""
14
-
15
  def __init__(self, reference_file: str = "config/biomarker_references.json"):
16
  """Load biomarker reference ranges from JSON file"""
17
  ref_path = Path(__file__).parent.parent / reference_file
18
- with open(ref_path, 'r') as f:
19
  self.references = json.load(f)['biomarkers']
20
-
21
  def validate_biomarker(
22
- self,
23
- name: str,
24
- value: float,
25
- gender: Optional[str] = None,
26
  threshold_pct: float = 0.0
27
  ) -> BiomarkerFlag:
28
  """
@@ -46,10 +46,10 @@ class BiomarkerValidator:
46
  reference_range="No reference data available",
47
  warning=f"No reference range found for {name}"
48
  )
49
-
50
  ref = self.references[name]
51
  unit = ref['unit']
52
-
53
  # Handle gender-specific ranges
54
  if ref.get('gender_specific', False) and gender:
55
  if gender.lower() in ['male', 'm']:
@@ -60,16 +60,16 @@ class BiomarkerValidator:
60
  normal = ref['normal_range']
61
  else:
62
  normal = ref['normal_range']
63
-
64
  min_val = normal.get('min', 0)
65
  max_val = normal.get('max', float('inf'))
66
  critical_low = ref.get('critical_low')
67
  critical_high = ref.get('critical_high')
68
-
69
  # Determine status
70
  status = "NORMAL"
71
  warning = None
72
-
73
  # Check critical values first (threshold_pct does not suppress critical alerts)
74
  if critical_low and value < critical_low:
75
  status = "CRITICAL_LOW"
@@ -88,9 +88,9 @@ class BiomarkerValidator:
88
  if deviation > threshold_pct:
89
  status = "HIGH"
90
  warning = f"{name} is {value} {unit}, above normal range ({min_val}-{max_val} {unit}). {ref['clinical_significance'].get('high', '')}"
91
-
92
  reference_range = f"{min_val}-{max_val} {unit}"
93
-
94
  return BiomarkerFlag(
95
  name=name,
96
  value=value,
@@ -99,13 +99,13 @@ class BiomarkerValidator:
99
  reference_range=reference_range,
100
  warning=warning
101
  )
102
-
103
  def validate_all(
104
  self,
105
- biomarkers: Dict[str, float],
106
- gender: Optional[str] = None,
107
  threshold_pct: float = 0.0
108
- ) -> Tuple[List[BiomarkerFlag], List[SafetyAlert]]:
109
  """
110
  Validate all biomarker values.
111
 
@@ -119,11 +119,11 @@ class BiomarkerValidator:
119
  """
120
  flags = []
121
  alerts = []
122
-
123
  for name, value in biomarkers.items():
124
  flag = self.validate_biomarker(name, value, gender, threshold_pct)
125
  flags.append(flag)
126
-
127
  # Generate safety alerts for critical values
128
  if flag.status in ["CRITICAL_LOW", "CRITICAL_HIGH"]:
129
  alerts.append(SafetyAlert(
@@ -140,18 +140,18 @@ class BiomarkerValidator:
140
  message=flag.warning or f"{name} out of normal range",
141
  action="Consult with healthcare provider"
142
  ))
143
-
144
  return flags, alerts
145
-
146
- def get_biomarker_info(self, name: str) -> Optional[Dict]:
147
  """Get reference information for a biomarker"""
148
  return self.references.get(name)
149
 
150
  def expected_biomarker_count(self) -> int:
151
  """Return expected number of biomarkers from reference ranges."""
152
  return len(self.references)
153
-
154
- def get_disease_relevant_biomarkers(self, disease: str) -> List[str]:
155
  """
156
  Get list of biomarkers most relevant to a specific disease.
157
 
@@ -159,19 +159,19 @@ class BiomarkerValidator:
159
  """
160
  disease_map = {
161
  "Diabetes": [
162
- "Glucose", "HbA1c", "Insulin", "BMI",
163
  "Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
164
  ],
165
  "Type 2 Diabetes": [
166
- "Glucose", "HbA1c", "Insulin", "BMI",
167
  "Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
168
  ],
169
  "Type 1 Diabetes": [
170
- "Glucose", "HbA1c", "Insulin", "BMI",
171
  "Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
172
  ],
173
  "Anemia": [
174
- "Hemoglobin", "Red Blood Cells", "Hematocrit",
175
  "Mean Corpuscular Volume", "Mean Corpuscular Hemoglobin",
176
  "Mean Corpuscular Hemoglobin Concentration"
177
  ],
@@ -189,5 +189,5 @@ class BiomarkerValidator:
189
  "Heart Rate", "BMI"
190
  ]
191
  }
192
-
193
  return disease_map.get(disease, [])
 
5
 
6
  import json
7
  from pathlib import Path
8
+
9
  from src.state import BiomarkerFlag, SafetyAlert
10
 
11
 
12
  class BiomarkerValidator:
13
  """Validates biomarker values against reference ranges"""
14
+
15
  def __init__(self, reference_file: str = "config/biomarker_references.json"):
16
  """Load biomarker reference ranges from JSON file"""
17
  ref_path = Path(__file__).parent.parent / reference_file
18
+ with open(ref_path) as f:
19
  self.references = json.load(f)['biomarkers']
20
+
21
  def validate_biomarker(
22
+ self,
23
+ name: str,
24
+ value: float,
25
+ gender: str | None = None,
26
  threshold_pct: float = 0.0
27
  ) -> BiomarkerFlag:
28
  """
 
46
  reference_range="No reference data available",
47
  warning=f"No reference range found for {name}"
48
  )
49
+
50
  ref = self.references[name]
51
  unit = ref['unit']
52
+
53
  # Handle gender-specific ranges
54
  if ref.get('gender_specific', False) and gender:
55
  if gender.lower() in ['male', 'm']:
 
60
  normal = ref['normal_range']
61
  else:
62
  normal = ref['normal_range']
63
+
64
  min_val = normal.get('min', 0)
65
  max_val = normal.get('max', float('inf'))
66
  critical_low = ref.get('critical_low')
67
  critical_high = ref.get('critical_high')
68
+
69
  # Determine status
70
  status = "NORMAL"
71
  warning = None
72
+
73
  # Check critical values first (threshold_pct does not suppress critical alerts)
74
  if critical_low and value < critical_low:
75
  status = "CRITICAL_LOW"
 
88
  if deviation > threshold_pct:
89
  status = "HIGH"
90
  warning = f"{name} is {value} {unit}, above normal range ({min_val}-{max_val} {unit}). {ref['clinical_significance'].get('high', '')}"
91
+
92
  reference_range = f"{min_val}-{max_val} {unit}"
93
+
94
  return BiomarkerFlag(
95
  name=name,
96
  value=value,
 
99
  reference_range=reference_range,
100
  warning=warning
101
  )
102
+
103
  def validate_all(
104
  self,
105
+ biomarkers: dict[str, float],
106
+ gender: str | None = None,
107
  threshold_pct: float = 0.0
108
+ ) -> tuple[list[BiomarkerFlag], list[SafetyAlert]]:
109
  """
110
  Validate all biomarker values.
111
 
 
119
  """
120
  flags = []
121
  alerts = []
122
+
123
  for name, value in biomarkers.items():
124
  flag = self.validate_biomarker(name, value, gender, threshold_pct)
125
  flags.append(flag)
126
+
127
  # Generate safety alerts for critical values
128
  if flag.status in ["CRITICAL_LOW", "CRITICAL_HIGH"]:
129
  alerts.append(SafetyAlert(
 
140
  message=flag.warning or f"{name} out of normal range",
141
  action="Consult with healthcare provider"
142
  ))
143
+
144
  return flags, alerts
145
+
146
+ def get_biomarker_info(self, name: str) -> dict | None:
147
  """Get reference information for a biomarker"""
148
  return self.references.get(name)
149
 
150
  def expected_biomarker_count(self) -> int:
151
  """Return expected number of biomarkers from reference ranges."""
152
  return len(self.references)
153
+
154
+ def get_disease_relevant_biomarkers(self, disease: str) -> list[str]:
155
  """
156
  Get list of biomarkers most relevant to a specific disease.
157
 
 
159
  """
160
  disease_map = {
161
  "Diabetes": [
162
+ "Glucose", "HbA1c", "Insulin", "BMI",
163
  "Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
164
  ],
165
  "Type 2 Diabetes": [
166
+ "Glucose", "HbA1c", "Insulin", "BMI",
167
  "Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
168
  ],
169
  "Type 1 Diabetes": [
170
+ "Glucose", "HbA1c", "Insulin", "BMI",
171
  "Triglycerides", "HDL Cholesterol", "LDL Cholesterol"
172
  ],
173
  "Anemia": [
174
+ "Hemoglobin", "Red Blood Cells", "Hematocrit",
175
  "Mean Corpuscular Volume", "Mean Corpuscular Hemoglobin",
176
  "Mean Corpuscular Hemoglobin Concentration"
177
  ],
 
189
  "Heart Rate", "BMI"
190
  ]
191
  }
192
+
193
  return disease_map.get(disease, [])
src/config.py CHANGED
@@ -3,8 +3,9 @@ MediGuard AI RAG-Helper
3
  Core configuration and SOP (Standard Operating Procedures) definitions
4
  """
5
 
 
 
6
  from pydantic import BaseModel, Field
7
- from typing import Literal, Dict, Any, List, Optional
8
 
9
 
10
  class ExplanationSOP(BaseModel):
@@ -13,28 +14,28 @@ class ExplanationSOP(BaseModel):
13
  This is the 'genome' that controls the entire RAG pipeline behavior.
14
  The Outer Loop (Director) will evolve these parameters to improve performance.
15
  """
16
-
17
  # === Agent Behavior Parameters ===
18
  biomarker_analyzer_threshold: float = Field(
19
  default=0.15,
20
  description="Percentage deviation from normal range to trigger a warning flag (0.15 = 15%)"
21
  )
22
-
23
  disease_explainer_k: int = Field(
24
  default=5,
25
  description="Number of top PDF chunks to retrieve for disease explanation"
26
  )
27
-
28
  linker_retrieval_k: int = Field(
29
  default=3,
30
  description="Number of chunks for biomarker-disease linking"
31
  )
32
-
33
  guideline_retrieval_k: int = Field(
34
  default=3,
35
  description="Number of chunks for clinical guidelines"
36
  )
37
-
38
  # === Prompts (Evolvable) ===
39
  planner_prompt: str = Field(
40
  default="""You are a medical AI coordinator. Create a structured execution plan for analyzing patient biomarkers and explaining a disease prediction.
@@ -49,7 +50,7 @@ Available specialist agents:
49
  Output a JSON with key 'plan' containing a list of tasks. Each task must have 'agent', 'task_description', and 'dependencies' keys.""",
50
  description="System prompt for the Planner Agent"
51
  )
52
-
53
  synthesizer_prompt: str = Field(
54
  default="""You are a medical communication specialist. Your task is to synthesize findings from specialist agents into a clear, patient-friendly clinical explanation.
55
 
@@ -64,39 +65,39 @@ Output a JSON with key 'plan' containing a list of tasks. Each task must have 'a
64
  Structure your output as specified in the output schema.""",
65
  description="System prompt for the Response Synthesizer"
66
  )
67
-
68
  explainer_detail_level: Literal["concise", "detailed", "comprehensive"] = Field(
69
  default="detailed",
70
  description="Level of detail in disease mechanism explanations"
71
  )
72
-
73
  # === Feature Flags ===
74
  use_guideline_agent: bool = Field(
75
  default=True,
76
  description="Whether to retrieve clinical guidelines and recommendations"
77
  )
78
-
79
  include_alternative_diagnoses: bool = Field(
80
  default=True,
81
  description="Whether to discuss alternative diagnoses from prediction probabilities"
82
  )
83
-
84
  require_pdf_citations: bool = Field(
85
  default=True,
86
  description="Whether to require PDF citations for all claims"
87
  )
88
-
89
  use_confidence_assessor: bool = Field(
90
  default=True,
91
  description="Whether to evaluate and report prediction confidence"
92
  )
93
-
94
  # === Safety Settings ===
95
  critical_value_alert_mode: Literal["strict", "moderate", "permissive"] = Field(
96
  default="strict",
97
  description="Threshold for critical value alerts"
98
  )
99
-
100
  # === Model Selection ===
101
  synthesizer_model: str = Field(
102
  default="default",
 
3
  Core configuration and SOP (Standard Operating Procedures) definitions
4
  """
5
 
6
+ from typing import Literal
7
+
8
  from pydantic import BaseModel, Field
 
9
 
10
 
11
  class ExplanationSOP(BaseModel):
 
14
  This is the 'genome' that controls the entire RAG pipeline behavior.
15
  The Outer Loop (Director) will evolve these parameters to improve performance.
16
  """
17
+
18
  # === Agent Behavior Parameters ===
19
  biomarker_analyzer_threshold: float = Field(
20
  default=0.15,
21
  description="Percentage deviation from normal range to trigger a warning flag (0.15 = 15%)"
22
  )
23
+
24
  disease_explainer_k: int = Field(
25
  default=5,
26
  description="Number of top PDF chunks to retrieve for disease explanation"
27
  )
28
+
29
  linker_retrieval_k: int = Field(
30
  default=3,
31
  description="Number of chunks for biomarker-disease linking"
32
  )
33
+
34
  guideline_retrieval_k: int = Field(
35
  default=3,
36
  description="Number of chunks for clinical guidelines"
37
  )
38
+
39
  # === Prompts (Evolvable) ===
40
  planner_prompt: str = Field(
41
  default="""You are a medical AI coordinator. Create a structured execution plan for analyzing patient biomarkers and explaining a disease prediction.
 
50
  Output a JSON with key 'plan' containing a list of tasks. Each task must have 'agent', 'task_description', and 'dependencies' keys.""",
51
  description="System prompt for the Planner Agent"
52
  )
53
+
54
  synthesizer_prompt: str = Field(
55
  default="""You are a medical communication specialist. Your task is to synthesize findings from specialist agents into a clear, patient-friendly clinical explanation.
56
 
 
65
  Structure your output as specified in the output schema.""",
66
  description="System prompt for the Response Synthesizer"
67
  )
68
+
69
  explainer_detail_level: Literal["concise", "detailed", "comprehensive"] = Field(
70
  default="detailed",
71
  description="Level of detail in disease mechanism explanations"
72
  )
73
+
74
  # === Feature Flags ===
75
  use_guideline_agent: bool = Field(
76
  default=True,
77
  description="Whether to retrieve clinical guidelines and recommendations"
78
  )
79
+
80
  include_alternative_diagnoses: bool = Field(
81
  default=True,
82
  description="Whether to discuss alternative diagnoses from prediction probabilities"
83
  )
84
+
85
  require_pdf_citations: bool = Field(
86
  default=True,
87
  description="Whether to require PDF citations for all claims"
88
  )
89
+
90
  use_confidence_assessor: bool = Field(
91
  default=True,
92
  description="Whether to evaluate and report prediction confidence"
93
  )
94
+
95
  # === Safety Settings ===
96
  critical_value_alert_mode: Literal["strict", "moderate", "permissive"] = Field(
97
  default="strict",
98
  description="Threshold for critical value alerts"
99
  )
100
+
101
  # === Model Selection ===
102
  synthesizer_model: str = Field(
103
  default="default",
src/database.py CHANGED
@@ -6,11 +6,11 @@ Provides SQLAlchemy engine/session factories and the declarative Base.
6
 
7
  from __future__ import annotations
8
 
 
9
  from functools import lru_cache
10
- from typing import Generator
11
 
12
  from sqlalchemy import create_engine
13
- from sqlalchemy.orm import Session, sessionmaker, DeclarativeBase
14
 
15
  from src.settings import get_settings
16
 
 
6
 
7
  from __future__ import annotations
8
 
9
+ from collections.abc import Generator
10
  from functools import lru_cache
 
11
 
12
  from sqlalchemy import create_engine
13
+ from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
14
 
15
  from src.settings import get_settings
16
 
src/dependencies.py CHANGED
@@ -6,9 +6,6 @@ Provides factory functions and ``Depends()`` for services used across routers.
6
 
7
  from __future__ import annotations
8
 
9
- from functools import lru_cache
10
-
11
- from src.settings import Settings, get_settings
12
  from src.services.cache.redis_cache import RedisCache, make_redis_cache
13
  from src.services.embeddings.service import EmbeddingService, make_embedding_service
14
  from src.services.langfuse.tracer import LangfuseTracer, make_langfuse_tracer
 
6
 
7
  from __future__ import annotations
8
 
 
 
 
9
  from src.services.cache.redis_cache import RedisCache, make_redis_cache
10
  from src.services.embeddings.service import EmbeddingService, make_embedding_service
11
  from src.services.langfuse.tracer import LangfuseTracer, make_langfuse_tracer
src/evaluation/__init__.py CHANGED
@@ -4,23 +4,23 @@ Exports 5D quality assessment framework components
4
  """
5
 
6
  from .evaluators import (
7
- GradedScore,
8
  EvaluationResult,
9
- evaluate_clinical_accuracy,
10
- evaluate_evidence_grounding,
11
  evaluate_actionability,
12
  evaluate_clarity,
 
 
13
  evaluate_safety_completeness,
14
- run_full_evaluation
15
  )
16
 
17
  __all__ = [
18
- 'GradedScore',
19
  'EvaluationResult',
20
- 'evaluate_clinical_accuracy',
21
- 'evaluate_evidence_grounding',
22
  'evaluate_actionability',
23
  'evaluate_clarity',
 
 
24
  'evaluate_safety_completeness',
25
  'run_full_evaluation'
26
  ]
 
4
  """
5
 
6
  from .evaluators import (
 
7
  EvaluationResult,
8
+ GradedScore,
 
9
  evaluate_actionability,
10
  evaluate_clarity,
11
+ evaluate_clinical_accuracy,
12
+ evaluate_evidence_grounding,
13
  evaluate_safety_completeness,
14
+ run_full_evaluation,
15
  )
16
 
17
  __all__ = [
 
18
  'EvaluationResult',
19
+ 'GradedScore',
 
20
  'evaluate_actionability',
21
  'evaluate_clarity',
22
+ 'evaluate_clinical_accuracy',
23
+ 'evaluate_evidence_grounding',
24
  'evaluate_safety_completeness',
25
  'run_full_evaluation'
26
  ]
src/evaluation/evaluators.py CHANGED
@@ -22,11 +22,13 @@ Usage:
22
  print(f"Average score: {result.average_score():.2f}")
23
  """
24
 
25
- import os
26
- from pydantic import BaseModel, Field
27
- from typing import Dict, Any, List
28
  import json
 
 
 
29
  from langchain_core.prompts import ChatPromptTemplate
 
 
30
  from src.llm_config import get_chat_model
31
 
32
  # Set to True for deterministic evaluation (testing)
@@ -46,8 +48,8 @@ class EvaluationResult(BaseModel):
46
  actionability: GradedScore
47
  clarity: GradedScore
48
  safety_completeness: GradedScore
49
-
50
- def to_vector(self) -> List[float]:
51
  """Extract scores as a vector for Pareto analysis"""
52
  return [
53
  self.clinical_accuracy.score,
@@ -56,7 +58,7 @@ class EvaluationResult(BaseModel):
56
  self.clarity.score,
57
  self.safety_completeness.score
58
  ]
59
-
60
  def average_score(self) -> float:
61
  """Calculate average of all 5 dimensions"""
62
  scores = self.to_vector()
@@ -65,7 +67,7 @@ class EvaluationResult(BaseModel):
65
 
66
  # Evaluator 1: Clinical Accuracy (LLM-as-Judge)
67
  def evaluate_clinical_accuracy(
68
- final_response: Dict[str, Any],
69
  pubmed_context: str
70
  ) -> GradedScore:
71
  """
@@ -77,13 +79,13 @@ def evaluate_clinical_accuracy(
77
  # Deterministic mode for testing
78
  if DETERMINISTIC_MODE:
79
  return _deterministic_clinical_accuracy(final_response, pubmed_context)
80
-
81
  # Use cloud LLM for evaluation (FREE via Groq/Gemini)
82
  evaluator_llm = get_chat_model(
83
  temperature=0.0,
84
  json_mode=True
85
  )
86
-
87
  prompt = ChatPromptTemplate.from_messages([
88
  ("system", """You are a medical expert evaluating clinical accuracy.
89
 
@@ -113,7 +115,7 @@ Respond ONLY with valid JSON in this format:
113
  {context}
114
  """)
115
  ])
116
-
117
  chain = prompt | evaluator_llm
118
  result = chain.invoke({
119
  "patient_summary": final_response['patient_summary'],
@@ -121,7 +123,7 @@ Respond ONLY with valid JSON in this format:
121
  "recommendations": final_response['clinical_recommendations'],
122
  "context": pubmed_context
123
  })
124
-
125
  # Parse JSON response
126
  try:
127
  content = result.content if isinstance(result.content, str) else str(result.content)
@@ -134,7 +136,7 @@ Respond ONLY with valid JSON in this format:
134
 
135
  # Evaluator 2: Evidence Grounding (Programmatic + LLM)
136
  def evaluate_evidence_grounding(
137
- final_response: Dict[str, Any]
138
  ) -> GradedScore:
139
  """
140
  Checks if all claims are backed by citations.
@@ -143,32 +145,32 @@ def evaluate_evidence_grounding(
143
  # Count citations
144
  pdf_refs = final_response['prediction_explanation'].get('pdf_references', [])
145
  citation_count = len(pdf_refs)
146
-
147
  # Check key drivers have evidence
148
  key_drivers = final_response['prediction_explanation'].get('key_drivers', [])
149
  drivers_with_evidence = sum(1 for d in key_drivers if d.get('evidence'))
150
-
151
  # Citation coverage score
152
  if len(key_drivers) > 0:
153
  coverage = drivers_with_evidence / len(key_drivers)
154
  else:
155
  coverage = 0.0
156
-
157
  # Base score from programmatic checks
158
  base_score = min(1.0, citation_count / 5.0) * 0.5 + coverage * 0.5
159
-
160
  reasoning = f"""
161
  Citations found: {citation_count}
162
  Key drivers with evidence: {drivers_with_evidence}/{len(key_drivers)}
163
  Citation coverage: {coverage:.1%}
164
  """
165
-
166
  return GradedScore(score=base_score, reasoning=reasoning.strip())
167
 
168
 
169
  # Evaluator 3: Clinical Actionability (LLM-as-Judge)
170
  def evaluate_actionability(
171
- final_response: Dict[str, Any]
172
  ) -> GradedScore:
173
  """
174
  Evaluates if recommendations are actionable and safe.
@@ -179,13 +181,13 @@ def evaluate_actionability(
179
  # Deterministic mode for testing
180
  if DETERMINISTIC_MODE:
181
  return _deterministic_actionability(final_response)
182
-
183
  # Use cloud LLM for evaluation (FREE via Groq/Gemini)
184
  evaluator_llm = get_chat_model(
185
  temperature=0.0,
186
  json_mode=True
187
  )
188
-
189
  prompt = ChatPromptTemplate.from_messages([
190
  ("system", """You are a clinical care coordinator evaluating actionability.
191
 
@@ -216,7 +218,7 @@ Respond ONLY with valid JSON in this format:
216
  {confidence}
217
  """)
218
  ])
219
-
220
  chain = prompt | evaluator_llm
221
  recs = final_response['clinical_recommendations']
222
  result = chain.invoke({
@@ -225,7 +227,7 @@ Respond ONLY with valid JSON in this format:
225
  "monitoring": recs.get('monitoring', []),
226
  "confidence": final_response['confidence_assessment']
227
  })
228
-
229
  # Parse JSON response
230
  try:
231
  parsed = json.loads(result.content if isinstance(result.content, str) else str(result.content))
@@ -237,7 +239,7 @@ Respond ONLY with valid JSON in this format:
237
 
238
  # Evaluator 4: Explainability Clarity (Programmatic)
239
  def evaluate_clarity(
240
- final_response: Dict[str, Any]
241
  ) -> GradedScore:
242
  """
243
  Measures readability and patient-friendliness.
@@ -248,16 +250,16 @@ def evaluate_clarity(
248
  # Deterministic mode for testing
249
  if DETERMINISTIC_MODE:
250
  return _deterministic_clarity(final_response)
251
-
252
  try:
253
  import textstat
254
  has_textstat = True
255
  except ImportError:
256
  has_textstat = False
257
-
258
  # Get patient narrative
259
  narrative = final_response['patient_summary'].get('narrative', '')
260
-
261
  if has_textstat:
262
  # Calculate readability (Flesch Reading Ease)
263
  # Score 60-70 = Standard (8th-9th grade)
@@ -275,24 +277,24 @@ def evaluate_clarity(
275
  readability_score = 0.9
276
  else:
277
  readability_score = max(0.5, 1.0 - (avg_words - 20) * 0.02)
278
-
279
  # Medical jargon detection (simple heuristic)
280
  medical_terms = [
281
  'pathophysiology', 'etiology', 'hemostasis', 'coagulation',
282
  'thrombocytopenia', 'erythropoiesis', 'gluconeogenesis'
283
  ]
284
  jargon_count = sum(1 for term in medical_terms if term.lower() in narrative.lower())
285
-
286
  # Length check (too short = vague, too long = overwhelming)
287
  word_count = len(narrative.split())
288
  optimal_length = 50 <= word_count <= 150
289
-
290
  # Scoring
291
  jargon_penalty = max(0.0, 1.0 - (jargon_count * 0.2))
292
  length_score = 1.0 if optimal_length else 0.7
293
-
294
  final_score = (readability_score * 0.5 + jargon_penalty * 0.3 + length_score * 0.2)
295
-
296
  if has_textstat:
297
  reasoning = f"""
298
  Flesch Reading Ease: {flesch_score:.1f} (Target: 60-70)
@@ -307,63 +309,63 @@ def evaluate_clarity(
307
  Word count: {word_count} (Optimal: 50-150)
308
  Note: textstat not available, using fallback metrics
309
  """
310
-
311
  return GradedScore(score=final_score, reasoning=reasoning.strip())
312
 
313
 
314
  # Evaluator 5: Safety & Completeness (Programmatic)
315
  def evaluate_safety_completeness(
316
- final_response: Dict[str, Any],
317
- biomarkers: Dict[str, float]
318
  ) -> GradedScore:
319
  """
320
  Checks if all safety concerns are flagged.
321
  Programmatic validation.
322
  """
323
  from src.biomarker_validator import BiomarkerValidator
324
-
325
  # Initialize validator
326
  validator = BiomarkerValidator()
327
-
328
  # Count out-of-range biomarkers
329
  out_of_range_count = 0
330
  critical_count = 0
331
-
332
  for name, value in biomarkers.items():
333
  result = validator.validate_biomarker(name, value) # Fixed: use validate_biomarker instead of validate_single
334
  if result.status in ['HIGH', 'LOW', 'CRITICAL_HIGH', 'CRITICAL_LOW']:
335
  out_of_range_count += 1
336
  if result.status in ['CRITICAL_HIGH', 'CRITICAL_LOW']:
337
  critical_count += 1
338
-
339
  # Count safety alerts in output
340
  safety_alerts = final_response.get('safety_alerts', [])
341
  alert_count = len(safety_alerts)
342
  critical_alerts = sum(1 for a in safety_alerts if a.get('severity') == 'CRITICAL')
343
-
344
  # Check if all critical values have alerts
345
  critical_coverage = critical_alerts / critical_count if critical_count > 0 else 1.0
346
-
347
  # Check for disclaimer
348
  has_disclaimer = 'disclaimer' in final_response.get('metadata', {})
349
-
350
  # Check for uncertainty acknowledgment
351
  limitations = final_response['confidence_assessment'].get('limitations', [])
352
  acknowledges_uncertainty = len(limitations) > 0
353
-
354
  # Scoring
355
  alert_score = min(1.0, alert_count / max(1, out_of_range_count))
356
  critical_score = min(1.0, critical_coverage)
357
  disclaimer_score = 1.0 if has_disclaimer else 0.0
358
  uncertainty_score = 1.0 if acknowledges_uncertainty else 0.5
359
-
360
  final_score = min(1.0, (
361
  alert_score * 0.4 +
362
  critical_score * 0.3 +
363
  disclaimer_score * 0.2 +
364
  uncertainty_score * 0.1
365
  ))
366
-
367
  reasoning = f"""
368
  Out-of-range biomarkers: {out_of_range_count}
369
  Critical values: {critical_count}
@@ -373,15 +375,15 @@ def evaluate_safety_completeness(
373
  Has disclaimer: {has_disclaimer}
374
  Acknowledges uncertainty: {acknowledges_uncertainty}
375
  """
376
-
377
  return GradedScore(score=final_score, reasoning=reasoning.strip())
378
 
379
 
380
  # Master Evaluation Function
381
  def run_full_evaluation(
382
- final_response: Dict[str, Any],
383
- agent_outputs: List[Any],
384
- biomarkers: Dict[str, float]
385
  ) -> EvaluationResult:
386
  """
387
  Orchestrates all 5 evaluators and returns complete assessment.
@@ -389,7 +391,7 @@ def run_full_evaluation(
389
  print("=" * 70)
390
  print("RUNNING 5D EVALUATION GAUNTLET")
391
  print("=" * 70)
392
-
393
  # Extract context from agent outputs
394
  pubmed_context = ""
395
  for output in agent_outputs:
@@ -402,27 +404,27 @@ def run_full_evaluation(
402
  else:
403
  pubmed_context = str(findings)
404
  break
405
-
406
  # Run all evaluators
407
  print("\n1. Evaluating Clinical Accuracy...")
408
  clinical_accuracy = evaluate_clinical_accuracy(final_response, pubmed_context)
409
-
410
  print("2. Evaluating Evidence Grounding...")
411
  evidence_grounding = evaluate_evidence_grounding(final_response)
412
-
413
  print("3. Evaluating Clinical Actionability...")
414
  actionability = evaluate_actionability(final_response)
415
-
416
  print("4. Evaluating Explainability Clarity...")
417
  clarity = evaluate_clarity(final_response)
418
-
419
  print("5. Evaluating Safety & Completeness...")
420
  safety_completeness = evaluate_safety_completeness(final_response, biomarkers)
421
-
422
  print("\n" + "=" * 70)
423
  print("EVALUATION COMPLETE")
424
  print("=" * 70)
425
-
426
  return EvaluationResult(
427
  clinical_accuracy=clinical_accuracy,
428
  evidence_grounding=evidence_grounding,
@@ -437,26 +439,26 @@ def run_full_evaluation(
437
  # ---------------------------------------------------------------------------
438
 
439
  def _deterministic_clinical_accuracy(
440
- final_response: Dict[str, Any],
441
  pubmed_context: str
442
  ) -> GradedScore:
443
  """Heuristic-based clinical accuracy (deterministic)."""
444
  score = 0.5
445
  reasons = []
446
-
447
  # Check if response has expected structure
448
  if final_response.get('patient_summary'):
449
  score += 0.1
450
  reasons.append("Has patient summary")
451
-
452
  if final_response.get('prediction_explanation'):
453
  score += 0.1
454
  reasons.append("Has prediction explanation")
455
-
456
  if final_response.get('clinical_recommendations'):
457
  score += 0.1
458
  reasons.append("Has clinical recommendations")
459
-
460
  # Check for citations
461
  pred = final_response.get('prediction_explanation', {})
462
  if isinstance(pred, dict):
@@ -464,7 +466,7 @@ def _deterministic_clinical_accuracy(
464
  if refs:
465
  score += min(0.2, len(refs) * 0.05)
466
  reasons.append(f"Has {len(refs)} citations")
467
-
468
  return GradedScore(
469
  score=min(1.0, score),
470
  reasoning="[DETERMINISTIC] " + "; ".join(reasons)
@@ -472,12 +474,12 @@ def _deterministic_clinical_accuracy(
472
 
473
 
474
  def _deterministic_actionability(
475
- final_response: Dict[str, Any]
476
  ) -> GradedScore:
477
  """Heuristic-based actionability (deterministic)."""
478
  score = 0.5
479
  reasons = []
480
-
481
  recs = final_response.get('clinical_recommendations', {})
482
  if isinstance(recs, dict):
483
  if recs.get('immediate_actions'):
@@ -489,7 +491,7 @@ def _deterministic_actionability(
489
  if recs.get('monitoring'):
490
  score += 0.1
491
  reasons.append("Has monitoring recommendations")
492
-
493
  return GradedScore(
494
  score=min(1.0, score),
495
  reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Missing recommendations"
@@ -497,12 +499,12 @@ def _deterministic_actionability(
497
 
498
 
499
  def _deterministic_clarity(
500
- final_response: Dict[str, Any]
501
  ) -> GradedScore:
502
  """Heuristic-based clarity (deterministic)."""
503
  score = 0.5
504
  reasons = []
505
-
506
  summary = final_response.get('patient_summary', '')
507
  if isinstance(summary, str):
508
  word_count = len(summary.split())
@@ -512,16 +514,16 @@ def _deterministic_clarity(
512
  elif word_count > 0:
513
  score += 0.1
514
  reasons.append("Has summary")
515
-
516
  # Check for structured output
517
  if final_response.get('biomarker_flags'):
518
  score += 0.15
519
  reasons.append("Has biomarker flags")
520
-
521
  if final_response.get('key_findings'):
522
  score += 0.15
523
  reasons.append("Has key findings")
524
-
525
  return GradedScore(
526
  score=min(1.0, score),
527
  reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Limited structure"
 
22
  print(f"Average score: {result.average_score():.2f}")
23
  """
24
 
 
 
 
25
  import json
26
+ import os
27
+ from typing import Any
28
+
29
  from langchain_core.prompts import ChatPromptTemplate
30
+ from pydantic import BaseModel, Field
31
+
32
  from src.llm_config import get_chat_model
33
 
34
  # Set to True for deterministic evaluation (testing)
 
48
  actionability: GradedScore
49
  clarity: GradedScore
50
  safety_completeness: GradedScore
51
+
52
+ def to_vector(self) -> list[float]:
53
  """Extract scores as a vector for Pareto analysis"""
54
  return [
55
  self.clinical_accuracy.score,
 
58
  self.clarity.score,
59
  self.safety_completeness.score
60
  ]
61
+
62
  def average_score(self) -> float:
63
  """Calculate average of all 5 dimensions"""
64
  scores = self.to_vector()
 
67
 
68
  # Evaluator 1: Clinical Accuracy (LLM-as-Judge)
69
  def evaluate_clinical_accuracy(
70
+ final_response: dict[str, Any],
71
  pubmed_context: str
72
  ) -> GradedScore:
73
  """
 
79
  # Deterministic mode for testing
80
  if DETERMINISTIC_MODE:
81
  return _deterministic_clinical_accuracy(final_response, pubmed_context)
82
+
83
  # Use cloud LLM for evaluation (FREE via Groq/Gemini)
84
  evaluator_llm = get_chat_model(
85
  temperature=0.0,
86
  json_mode=True
87
  )
88
+
89
  prompt = ChatPromptTemplate.from_messages([
90
  ("system", """You are a medical expert evaluating clinical accuracy.
91
 
 
115
  {context}
116
  """)
117
  ])
118
+
119
  chain = prompt | evaluator_llm
120
  result = chain.invoke({
121
  "patient_summary": final_response['patient_summary'],
 
123
  "recommendations": final_response['clinical_recommendations'],
124
  "context": pubmed_context
125
  })
126
+
127
  # Parse JSON response
128
  try:
129
  content = result.content if isinstance(result.content, str) else str(result.content)
 
136
 
137
  # Evaluator 2: Evidence Grounding (Programmatic + LLM)
138
  def evaluate_evidence_grounding(
139
+ final_response: dict[str, Any]
140
  ) -> GradedScore:
141
  """
142
  Checks if all claims are backed by citations.
 
145
  # Count citations
146
  pdf_refs = final_response['prediction_explanation'].get('pdf_references', [])
147
  citation_count = len(pdf_refs)
148
+
149
  # Check key drivers have evidence
150
  key_drivers = final_response['prediction_explanation'].get('key_drivers', [])
151
  drivers_with_evidence = sum(1 for d in key_drivers if d.get('evidence'))
152
+
153
  # Citation coverage score
154
  if len(key_drivers) > 0:
155
  coverage = drivers_with_evidence / len(key_drivers)
156
  else:
157
  coverage = 0.0
158
+
159
  # Base score from programmatic checks
160
  base_score = min(1.0, citation_count / 5.0) * 0.5 + coverage * 0.5
161
+
162
  reasoning = f"""
163
  Citations found: {citation_count}
164
  Key drivers with evidence: {drivers_with_evidence}/{len(key_drivers)}
165
  Citation coverage: {coverage:.1%}
166
  """
167
+
168
  return GradedScore(score=base_score, reasoning=reasoning.strip())
169
 
170
 
171
  # Evaluator 3: Clinical Actionability (LLM-as-Judge)
172
  def evaluate_actionability(
173
+ final_response: dict[str, Any]
174
  ) -> GradedScore:
175
  """
176
  Evaluates if recommendations are actionable and safe.
 
181
  # Deterministic mode for testing
182
  if DETERMINISTIC_MODE:
183
  return _deterministic_actionability(final_response)
184
+
185
  # Use cloud LLM for evaluation (FREE via Groq/Gemini)
186
  evaluator_llm = get_chat_model(
187
  temperature=0.0,
188
  json_mode=True
189
  )
190
+
191
  prompt = ChatPromptTemplate.from_messages([
192
  ("system", """You are a clinical care coordinator evaluating actionability.
193
 
 
218
  {confidence}
219
  """)
220
  ])
221
+
222
  chain = prompt | evaluator_llm
223
  recs = final_response['clinical_recommendations']
224
  result = chain.invoke({
 
227
  "monitoring": recs.get('monitoring', []),
228
  "confidence": final_response['confidence_assessment']
229
  })
230
+
231
  # Parse JSON response
232
  try:
233
  parsed = json.loads(result.content if isinstance(result.content, str) else str(result.content))
 
239
 
240
  # Evaluator 4: Explainability Clarity (Programmatic)
241
  def evaluate_clarity(
242
+ final_response: dict[str, Any]
243
  ) -> GradedScore:
244
  """
245
  Measures readability and patient-friendliness.
 
250
  # Deterministic mode for testing
251
  if DETERMINISTIC_MODE:
252
  return _deterministic_clarity(final_response)
253
+
254
  try:
255
  import textstat
256
  has_textstat = True
257
  except ImportError:
258
  has_textstat = False
259
+
260
  # Get patient narrative
261
  narrative = final_response['patient_summary'].get('narrative', '')
262
+
263
  if has_textstat:
264
  # Calculate readability (Flesch Reading Ease)
265
  # Score 60-70 = Standard (8th-9th grade)
 
277
  readability_score = 0.9
278
  else:
279
  readability_score = max(0.5, 1.0 - (avg_words - 20) * 0.02)
280
+
281
  # Medical jargon detection (simple heuristic)
282
  medical_terms = [
283
  'pathophysiology', 'etiology', 'hemostasis', 'coagulation',
284
  'thrombocytopenia', 'erythropoiesis', 'gluconeogenesis'
285
  ]
286
  jargon_count = sum(1 for term in medical_terms if term.lower() in narrative.lower())
287
+
288
  # Length check (too short = vague, too long = overwhelming)
289
  word_count = len(narrative.split())
290
  optimal_length = 50 <= word_count <= 150
291
+
292
  # Scoring
293
  jargon_penalty = max(0.0, 1.0 - (jargon_count * 0.2))
294
  length_score = 1.0 if optimal_length else 0.7
295
+
296
  final_score = (readability_score * 0.5 + jargon_penalty * 0.3 + length_score * 0.2)
297
+
298
  if has_textstat:
299
  reasoning = f"""
300
  Flesch Reading Ease: {flesch_score:.1f} (Target: 60-70)
 
309
  Word count: {word_count} (Optimal: 50-150)
310
  Note: textstat not available, using fallback metrics
311
  """
312
+
313
  return GradedScore(score=final_score, reasoning=reasoning.strip())
314
 
315
 
316
  # Evaluator 5: Safety & Completeness (Programmatic)
317
  def evaluate_safety_completeness(
318
+ final_response: dict[str, Any],
319
+ biomarkers: dict[str, float]
320
  ) -> GradedScore:
321
  """
322
  Checks if all safety concerns are flagged.
323
  Programmatic validation.
324
  """
325
  from src.biomarker_validator import BiomarkerValidator
326
+
327
  # Initialize validator
328
  validator = BiomarkerValidator()
329
+
330
  # Count out-of-range biomarkers
331
  out_of_range_count = 0
332
  critical_count = 0
333
+
334
  for name, value in biomarkers.items():
335
  result = validator.validate_biomarker(name, value) # Fixed: use validate_biomarker instead of validate_single
336
  if result.status in ['HIGH', 'LOW', 'CRITICAL_HIGH', 'CRITICAL_LOW']:
337
  out_of_range_count += 1
338
  if result.status in ['CRITICAL_HIGH', 'CRITICAL_LOW']:
339
  critical_count += 1
340
+
341
  # Count safety alerts in output
342
  safety_alerts = final_response.get('safety_alerts', [])
343
  alert_count = len(safety_alerts)
344
  critical_alerts = sum(1 for a in safety_alerts if a.get('severity') == 'CRITICAL')
345
+
346
  # Check if all critical values have alerts
347
  critical_coverage = critical_alerts / critical_count if critical_count > 0 else 1.0
348
+
349
  # Check for disclaimer
350
  has_disclaimer = 'disclaimer' in final_response.get('metadata', {})
351
+
352
  # Check for uncertainty acknowledgment
353
  limitations = final_response['confidence_assessment'].get('limitations', [])
354
  acknowledges_uncertainty = len(limitations) > 0
355
+
356
  # Scoring
357
  alert_score = min(1.0, alert_count / max(1, out_of_range_count))
358
  critical_score = min(1.0, critical_coverage)
359
  disclaimer_score = 1.0 if has_disclaimer else 0.0
360
  uncertainty_score = 1.0 if acknowledges_uncertainty else 0.5
361
+
362
  final_score = min(1.0, (
363
  alert_score * 0.4 +
364
  critical_score * 0.3 +
365
  disclaimer_score * 0.2 +
366
  uncertainty_score * 0.1
367
  ))
368
+
369
  reasoning = f"""
370
  Out-of-range biomarkers: {out_of_range_count}
371
  Critical values: {critical_count}
 
375
  Has disclaimer: {has_disclaimer}
376
  Acknowledges uncertainty: {acknowledges_uncertainty}
377
  """
378
+
379
  return GradedScore(score=final_score, reasoning=reasoning.strip())
380
 
381
 
382
  # Master Evaluation Function
383
  def run_full_evaluation(
384
+ final_response: dict[str, Any],
385
+ agent_outputs: list[Any],
386
+ biomarkers: dict[str, float]
387
  ) -> EvaluationResult:
388
  """
389
  Orchestrates all 5 evaluators and returns complete assessment.
 
391
  print("=" * 70)
392
  print("RUNNING 5D EVALUATION GAUNTLET")
393
  print("=" * 70)
394
+
395
  # Extract context from agent outputs
396
  pubmed_context = ""
397
  for output in agent_outputs:
 
404
  else:
405
  pubmed_context = str(findings)
406
  break
407
+
408
  # Run all evaluators
409
  print("\n1. Evaluating Clinical Accuracy...")
410
  clinical_accuracy = evaluate_clinical_accuracy(final_response, pubmed_context)
411
+
412
  print("2. Evaluating Evidence Grounding...")
413
  evidence_grounding = evaluate_evidence_grounding(final_response)
414
+
415
  print("3. Evaluating Clinical Actionability...")
416
  actionability = evaluate_actionability(final_response)
417
+
418
  print("4. Evaluating Explainability Clarity...")
419
  clarity = evaluate_clarity(final_response)
420
+
421
  print("5. Evaluating Safety & Completeness...")
422
  safety_completeness = evaluate_safety_completeness(final_response, biomarkers)
423
+
424
  print("\n" + "=" * 70)
425
  print("EVALUATION COMPLETE")
426
  print("=" * 70)
427
+
428
  return EvaluationResult(
429
  clinical_accuracy=clinical_accuracy,
430
  evidence_grounding=evidence_grounding,
 
439
  # ---------------------------------------------------------------------------
440
 
441
  def _deterministic_clinical_accuracy(
442
+ final_response: dict[str, Any],
443
  pubmed_context: str
444
  ) -> GradedScore:
445
  """Heuristic-based clinical accuracy (deterministic)."""
446
  score = 0.5
447
  reasons = []
448
+
449
  # Check if response has expected structure
450
  if final_response.get('patient_summary'):
451
  score += 0.1
452
  reasons.append("Has patient summary")
453
+
454
  if final_response.get('prediction_explanation'):
455
  score += 0.1
456
  reasons.append("Has prediction explanation")
457
+
458
  if final_response.get('clinical_recommendations'):
459
  score += 0.1
460
  reasons.append("Has clinical recommendations")
461
+
462
  # Check for citations
463
  pred = final_response.get('prediction_explanation', {})
464
  if isinstance(pred, dict):
 
466
  if refs:
467
  score += min(0.2, len(refs) * 0.05)
468
  reasons.append(f"Has {len(refs)} citations")
469
+
470
  return GradedScore(
471
  score=min(1.0, score),
472
  reasoning="[DETERMINISTIC] " + "; ".join(reasons)
 
474
 
475
 
476
  def _deterministic_actionability(
477
+ final_response: dict[str, Any]
478
  ) -> GradedScore:
479
  """Heuristic-based actionability (deterministic)."""
480
  score = 0.5
481
  reasons = []
482
+
483
  recs = final_response.get('clinical_recommendations', {})
484
  if isinstance(recs, dict):
485
  if recs.get('immediate_actions'):
 
491
  if recs.get('monitoring'):
492
  score += 0.1
493
  reasons.append("Has monitoring recommendations")
494
+
495
  return GradedScore(
496
  score=min(1.0, score),
497
  reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Missing recommendations"
 
499
 
500
 
501
  def _deterministic_clarity(
502
+ final_response: dict[str, Any]
503
  ) -> GradedScore:
504
  """Heuristic-based clarity (deterministic)."""
505
  score = 0.5
506
  reasons = []
507
+
508
  summary = final_response.get('patient_summary', '')
509
  if isinstance(summary, str):
510
  word_count = len(summary.split())
 
514
  elif word_count > 0:
515
  score += 0.1
516
  reasons.append("Has summary")
517
+
518
  # Check for structured output
519
  if final_response.get('biomarker_flags'):
520
  score += 0.15
521
  reasons.append("Has biomarker flags")
522
+
523
  if final_response.get('key_findings'):
524
  score += 0.15
525
  reasons.append("Has key findings")
526
+
527
  return GradedScore(
528
  score=min(1.0, score),
529
  reasoning="[DETERMINISTIC] " + "; ".join(reasons) if reasons else "[DETERMINISTIC] Limited structure"
src/exceptions.py CHANGED
@@ -6,15 +6,14 @@ Each service layer raises its own exception type so callers can handle
6
  failures precisely without leaking implementation details.
7
  """
8
 
9
- from typing import Any, Dict, Optional
10
-
11
 
12
  # ── Base ──────────────────────────────────────────────────────────────────────
13
 
14
  class MediGuardError(Exception):
15
  """Root exception for the entire MediGuard AI application."""
16
 
17
- def __init__(self, message: str = "", *, details: Optional[Dict[str, Any]] = None):
18
  self.details = details or {}
19
  super().__init__(message)
20
 
 
6
  failures precisely without leaking implementation details.
7
  """
8
 
9
+ from typing import Any
 
10
 
11
  # ── Base ──────────────────────────────────────────────────────────────────────
12
 
13
  class MediGuardError(Exception):
14
  """Root exception for the entire MediGuard AI application."""
15
 
16
+ def __init__(self, message: str = "", *, details: dict[str, Any] | None = None):
17
  self.details = details or {}
18
  super().__init__(message)
19
 
src/gradio_app.py CHANGED
@@ -17,15 +17,33 @@ logger = logging.getLogger(__name__)
17
  API_BASE = os.getenv("MEDIGUARD_API_URL", "http://localhost:8000")
18
 
19
 
20
- def _call_ask(question: str) -> str:
21
- """Call the /ask endpoint."""
 
 
 
 
 
 
 
22
  try:
23
- with httpx.Client(timeout=60.0) as client:
24
- resp = client.post(f"{API_BASE}/ask", json={"question": question})
25
  resp.raise_for_status()
26
- return resp.json().get("answer", "No answer returned.")
 
 
 
 
 
 
 
 
 
 
 
27
  except Exception as exc:
28
- return f"Error: {exc}"
 
29
 
30
 
31
  def _call_analyze(biomarkers_json: str) -> str:
@@ -47,7 +65,7 @@ def _call_analyze(biomarkers_json: str) -> str:
47
  return f"Error: {exc}"
48
 
49
 
50
- def launch_gradio(share: bool = False) -> None:
51
  """Launch the Gradio interface."""
52
  try:
53
  import gradio as gr
@@ -62,14 +80,27 @@ def launch_gradio(share: bool = False) -> None:
62
  )
63
 
64
  with gr.Tab("Ask a Question"):
65
- question_input = gr.Textbox(
66
- label="Medical Question",
67
- placeholder="e.g., What does a high HbA1c level indicate?",
68
- lines=3,
69
- )
70
- ask_btn = gr.Button("Ask", variant="primary")
71
- answer_output = gr.Textbox(label="Answer", lines=15, interactive=False)
72
- ask_btn.click(fn=_call_ask, inputs=question_input, outputs=answer_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  with gr.Tab("Analyze Biomarkers"):
75
  bio_input = gr.Textbox(
@@ -82,20 +113,28 @@ def launch_gradio(share: bool = False) -> None:
82
  analyze_btn.click(fn=_call_analyze, inputs=bio_input, outputs=analysis_output)
83
 
84
  with gr.Tab("Search Knowledge Base"):
85
- search_input = gr.Textbox(
86
- label="Search Query",
87
- placeholder="e.g., diabetes management guidelines",
88
- lines=2,
89
- )
 
 
 
 
 
 
 
 
90
  search_btn = gr.Button("Search", variant="primary")
91
  search_output = gr.Textbox(label="Results", lines=15, interactive=False)
92
 
93
- def _call_search(query: str) -> str:
94
  try:
95
  with httpx.Client(timeout=30.0) as client:
96
  resp = client.post(
97
  f"{API_BASE}/search",
98
- json={"query": query, "top_k": 5, "mode": "hybrid"},
99
  )
100
  resp.raise_for_status()
101
  data = resp.json()
@@ -112,10 +151,11 @@ def launch_gradio(share: bool = False) -> None:
112
  except Exception as exc:
113
  return f"Error: {exc}"
114
 
115
- search_btn.click(fn=_call_search, inputs=search_input, outputs=search_output)
116
 
117
- demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
118
 
119
 
120
  if __name__ == "__main__":
121
- launch_gradio()
 
 
17
  API_BASE = os.getenv("MEDIGUARD_API_URL", "http://localhost:8000")
18
 
19
 
20
+ def ask_stream(question: str, history: list, model: str):
21
+ """Call the /ask/stream endpoint."""
22
+ history = history or []
23
+ if not question.strip():
24
+ yield "", history
25
+ return
26
+
27
+ history.append((question, ""))
28
+
29
  try:
30
+ with httpx.stream("POST", f"{API_BASE}/ask/stream", json={"question": question}, timeout=60.0) as resp:
 
31
  resp.raise_for_status()
32
+ for line in resp.iter_lines():
33
+ if line.startswith("data: "):
34
+ content = line[6:]
35
+ if content == "[DONE]":
36
+ break
37
+ try:
38
+ data = json.loads(content)
39
+ current_bot_msg = history[-1][1] + data.get("text", "")
40
+ history[-1] = (question, current_bot_msg)
41
+ yield "", history
42
+ except Exception as trace_exc:
43
+ logger.debug("Failed to parse streaming chunk: %s", trace_exc)
44
  except Exception as exc:
45
+ history[-1] = (question, f"Error: {exc}")
46
+ yield "", history
47
 
48
 
49
  def _call_analyze(biomarkers_json: str) -> str:
 
65
  return f"Error: {exc}"
66
 
67
 
68
+ def launch_gradio(share: bool = False, server_port: int = 7860) -> None:
69
  """Launch the Gradio interface."""
70
  try:
71
  import gradio as gr
 
80
  )
81
 
82
  with gr.Tab("Ask a Question"):
83
+ with gr.Row():
84
+ with gr.Column(scale=3):
85
+ chatbot = gr.Chatbot(label="Medical Q&A History", height=400)
86
+ question_input = gr.Textbox(
87
+ label="Medical Question",
88
+ placeholder="e.g., What does a high HbA1c level indicate?",
89
+ lines=2,
90
+ )
91
+ with gr.Row():
92
+ ask_btn = gr.Button("Ask (Streaming)", variant="primary")
93
+ clear_btn = gr.Button("Clear History")
94
+
95
+ with gr.Column(scale=1):
96
+ model_selector = gr.Dropdown(
97
+ choices=["llama-3.3-70b-versatile", "gemini-2.0-flash", "llama3.1:8b"],
98
+ value="llama-3.3-70b-versatile",
99
+ label="LLM Provider/Model"
100
+ )
101
+
102
+ ask_btn.click(fn=ask_stream, inputs=[question_input, chatbot, model_selector], outputs=[question_input, chatbot])
103
+ clear_btn.click(fn=lambda: ([], ""), outputs=[chatbot, question_input])
104
 
105
  with gr.Tab("Analyze Biomarkers"):
106
  bio_input = gr.Textbox(
 
113
  analyze_btn.click(fn=_call_analyze, inputs=bio_input, outputs=analysis_output)
114
 
115
  with gr.Tab("Search Knowledge Base"):
116
+ with gr.Row():
117
+ search_input = gr.Textbox(
118
+ label="Search Query",
119
+ placeholder="e.g., diabetes management guidelines",
120
+ lines=2,
121
+ scale=3
122
+ )
123
+ search_mode = gr.Radio(
124
+ choices=["hybrid", "bm25", "vector"],
125
+ value="hybrid",
126
+ label="Search Strategy",
127
+ scale=1
128
+ )
129
  search_btn = gr.Button("Search", variant="primary")
130
  search_output = gr.Textbox(label="Results", lines=15, interactive=False)
131
 
132
+ def _call_search(query: str, mode: str) -> str:
133
  try:
134
  with httpx.Client(timeout=30.0) as client:
135
  resp = client.post(
136
  f"{API_BASE}/search",
137
+ json={"query": query, "top_k": 5, "mode": mode},
138
  )
139
  resp.raise_for_status()
140
  data = resp.json()
 
151
  except Exception as exc:
152
  return f"Error: {exc}"
153
 
154
+ search_btn.click(fn=_call_search, inputs=[search_input, search_mode], outputs=search_output)
155
 
156
+ demo.launch(server_name="0.0.0.0", server_port=server_port, share=share)
157
 
158
 
159
  if __name__ == "__main__":
160
+ port = int(os.environ.get("GRADIO_PORT", 7860))
161
+ launch_gradio(server_port=port)
src/llm_config.py CHANGED
@@ -14,7 +14,8 @@ Environment Variables (supports both naming conventions):
14
 
15
  import os
16
  import threading
17
- from typing import Literal, Optional
 
18
  from dotenv import load_dotenv
19
 
20
  # Load environment variables
@@ -64,8 +65,8 @@ DEFAULT_LLM_PROVIDER = get_default_llm_provider()
64
 
65
 
66
  def get_chat_model(
67
- provider: Optional[Literal["groq", "gemini", "ollama"]] = None,
68
- model: Optional[str] = None,
69
  temperature: float = 0.0,
70
  json_mode: bool = False
71
  ):
@@ -83,61 +84,61 @@ def get_chat_model(
83
  """
84
  # Use dynamic lookup to get current provider from environment
85
  provider = provider or get_default_llm_provider()
86
-
87
  if provider == "groq":
88
  from langchain_groq import ChatGroq
89
-
90
  api_key = get_groq_api_key()
91
  if not api_key:
92
  raise ValueError(
93
  "GROQ_API_KEY not found in environment.\n"
94
  "Get your FREE API key at: https://console.groq.com/keys"
95
  )
96
-
97
  # Use model from environment or default
98
  model = model or get_groq_model()
99
-
100
  return ChatGroq(
101
  model=model,
102
  temperature=temperature,
103
  api_key=api_key,
104
  model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {}
105
  )
106
-
107
  elif provider == "gemini":
108
  from langchain_google_genai import ChatGoogleGenerativeAI
109
-
110
  api_key = get_google_api_key()
111
  if not api_key:
112
  raise ValueError(
113
  "GOOGLE_API_KEY not found in environment.\n"
114
  "Get your FREE API key at: https://aistudio.google.com/app/apikey"
115
  )
116
-
117
  # Use model from environment or default
118
  model = model or get_gemini_model()
119
-
120
  return ChatGoogleGenerativeAI(
121
  model=model,
122
  temperature=temperature,
123
  google_api_key=api_key,
124
  convert_system_message_to_human=True
125
  )
126
-
127
  elif provider == "ollama":
128
  try:
129
  from langchain_ollama import ChatOllama
130
  except ImportError:
131
  from langchain_community.chat_models import ChatOllama
132
-
133
  model = model or "llama3.1:8b"
134
-
135
  return ChatOllama(
136
  model=model,
137
  temperature=temperature,
138
  format='json' if json_mode else None
139
  )
140
-
141
  else:
142
  raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
143
 
@@ -147,7 +148,7 @@ def get_embedding_provider() -> str:
147
  return _get_env_with_fallback("EMBEDDING_PROVIDER", "EMBEDDING__PROVIDER", "huggingface")
148
 
149
 
150
- def get_embedding_model(provider: Optional[Literal["jina", "google", "huggingface", "ollama"]] = None):
151
  """
152
  Get embedding model for vector search.
153
 
@@ -162,7 +163,7 @@ def get_embedding_model(provider: Optional[Literal["jina", "google", "huggingfac
162
  which has automatic fallback chain: Jina → Google → HuggingFace.
163
  """
164
  provider = provider or get_embedding_provider()
165
-
166
  if provider == "jina":
167
  # Try Jina AI embeddings first (high quality, 1024d)
168
  jina_key = _get_env_with_fallback("JINA_API_KEY", "EMBEDDING__JINA_API_KEY", "")
@@ -178,15 +179,15 @@ def get_embedding_model(provider: Optional[Literal["jina", "google", "huggingfac
178
  else:
179
  print("WARN: JINA_API_KEY not found. Falling back to Google embeddings.")
180
  return get_embedding_model("google")
181
-
182
  elif provider == "google":
183
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
184
-
185
  api_key = get_google_api_key()
186
  if not api_key:
187
  print("WARN: GOOGLE_API_KEY not found. Falling back to HuggingFace embeddings.")
188
  return get_embedding_model("huggingface")
189
-
190
  try:
191
  return GoogleGenerativeAIEmbeddings(
192
  model="models/text-embedding-004",
@@ -196,33 +197,33 @@ def get_embedding_model(provider: Optional[Literal["jina", "google", "huggingfac
196
  print(f"WARN: Google embeddings failed: {e}")
197
  print("INFO: Falling back to HuggingFace embeddings...")
198
  return get_embedding_model("huggingface")
199
-
200
  elif provider == "huggingface":
201
  try:
202
  from langchain_huggingface import HuggingFaceEmbeddings
203
  except ImportError:
204
  from langchain_community.embeddings import HuggingFaceEmbeddings
205
-
206
  return HuggingFaceEmbeddings(
207
  model_name="sentence-transformers/all-MiniLM-L6-v2"
208
  )
209
-
210
  elif provider == "ollama":
211
  try:
212
  from langchain_ollama import OllamaEmbeddings
213
  except ImportError:
214
  from langchain_community.embeddings import OllamaEmbeddings
215
-
216
  return OllamaEmbeddings(model="nomic-embed-text")
217
-
218
  else:
219
  raise ValueError(f"Unknown embedding provider: {provider}")
220
 
221
 
222
  class LLMConfig:
223
  """Central configuration for all LLM models"""
224
-
225
- def __init__(self, provider: Optional[str] = None, lazy: bool = True):
226
  """
227
  Initialize all model clients.
228
 
@@ -236,7 +237,7 @@ class LLMConfig:
236
  self._initialized = False
237
  self._initialized_provider = None # Track which provider was initialized
238
  self._lock = threading.Lock()
239
-
240
  # Lazy-initialized model instances
241
  self._planner = None
242
  self._analyzer = None
@@ -245,15 +246,15 @@ class LLMConfig:
245
  self._synthesizer_8b = None
246
  self._director = None
247
  self._embedding_model = None
248
-
249
  if not lazy:
250
  self._initialize_models()
251
-
252
  @property
253
  def provider(self) -> str:
254
  """Get current provider (dynamic lookup if not explicitly set)."""
255
  return self._explicit_provider or get_default_llm_provider()
256
-
257
  def _check_provider_change(self):
258
  """Check if provider changed and reinitialize if needed."""
259
  current = self.provider
@@ -266,120 +267,120 @@ class LLMConfig:
266
  self._synthesizer_7b = None
267
  self._synthesizer_8b = None
268
  self._director = None
269
-
270
  def _initialize_models(self):
271
  """Initialize all model clients (called on first use if lazy)"""
272
  self._check_provider_change()
273
-
274
  if self._initialized:
275
  return
276
-
277
  with self._lock:
278
  # Double-checked locking
279
  if self._initialized:
280
  return
281
-
282
  print(f"Initializing LLM models with provider: {self.provider.upper()}")
283
-
284
  # Fast model for structured tasks (planning, analysis)
285
  self._planner = get_chat_model(
286
  provider=self.provider,
287
  temperature=0.0,
288
  json_mode=True
289
  )
290
-
291
  # Fast model for biomarker analysis and quick tasks
292
  self._analyzer = get_chat_model(
293
  provider=self.provider,
294
  temperature=0.0
295
  )
296
-
297
  # Medium model for RAG retrieval and explanation
298
  self._explainer = get_chat_model(
299
  provider=self.provider,
300
  temperature=0.2
301
  )
302
-
303
  # Configurable synthesizers
304
  self._synthesizer_7b = get_chat_model(
305
  provider=self.provider,
306
  temperature=0.2
307
  )
308
-
309
  self._synthesizer_8b = get_chat_model(
310
  provider=self.provider,
311
  temperature=0.2
312
  )
313
-
314
  # Director for Outer Loop
315
  self._director = get_chat_model(
316
  provider=self.provider,
317
  temperature=0.0,
318
  json_mode=True
319
  )
320
-
321
- # Embedding model for RAG
322
  self._embedding_model = get_embedding_model()
323
-
324
  self._initialized = True
325
  self._initialized_provider = self.provider
326
-
327
  @property
328
  def planner(self):
329
  self._initialize_models()
330
  return self._planner
331
-
332
  @property
333
  def analyzer(self):
334
  self._initialize_models()
335
  return self._analyzer
336
-
337
  @property
338
  def explainer(self):
339
  self._initialize_models()
340
  return self._explainer
341
-
342
  @property
343
  def synthesizer_7b(self):
344
  self._initialize_models()
345
  return self._synthesizer_7b
346
-
347
  @property
348
  def synthesizer_8b(self):
349
  self._initialize_models()
350
  return self._synthesizer_8b
351
-
352
  @property
353
  def director(self):
354
  self._initialize_models()
355
  return self._director
356
-
357
  @property
358
  def embedding_model(self):
359
  self._initialize_models()
360
  return self._embedding_model
361
-
362
- def get_synthesizer(self, model_name: Optional[str] = None):
363
  """Get synthesizer model (for backward compatibility)"""
364
  if model_name:
365
  return get_chat_model(provider=self.provider, model=model_name, temperature=0.2)
366
  return self.synthesizer_8b
367
-
368
  def print_config(self):
369
  """Print current LLM configuration"""
370
  print("=" * 60)
371
  print("MediGuard AI RAG-Helper - LLM Configuration")
372
  print("=" * 60)
373
  print(f"Provider: {self.provider.upper()}")
374
-
375
  if self.provider == "groq":
376
- print(f"Model: llama-3.3-70b-versatile (FREE)")
377
  elif self.provider == "gemini":
378
- print(f"Model: gemini-2.0-flash (FREE)")
379
  else:
380
- print(f"Model: llama3.1:8b (local)")
381
-
382
- print(f"Embeddings: Google Gemini (FREE)")
383
  print("=" * 60)
384
 
385
 
@@ -387,7 +388,7 @@ class LLMConfig:
387
  llm_config = LLMConfig()
388
 
389
 
390
- def get_synthesizer(model_name: Optional[str] = None):
391
  """Module-level convenience: get a synthesizer LLM instance."""
392
  return llm_config.get_synthesizer(model_name)
393
 
@@ -395,7 +396,7 @@ def get_synthesizer(model_name: Optional[str] = None):
395
  def check_api_connection():
396
  """Verify API connection and keys are configured"""
397
  provider = DEFAULT_LLM_PROVIDER
398
-
399
  try:
400
  if provider == "groq":
401
  api_key = os.getenv("GROQ_API_KEY")
@@ -404,13 +405,13 @@ def check_api_connection():
404
  print("\n Get your FREE API key at:")
405
  print(" https://console.groq.com/keys")
406
  return False
407
-
408
  # Test connection
409
  test_model = get_chat_model("groq")
410
  response = test_model.invoke("Say 'OK' in one word")
411
  print("OK: Groq API connection successful")
412
  return True
413
-
414
  elif provider == "gemini":
415
  api_key = os.getenv("GOOGLE_API_KEY")
416
  if not api_key:
@@ -418,12 +419,12 @@ def check_api_connection():
418
  print("\n Get your FREE API key at:")
419
  print(" https://aistudio.google.com/app/apikey")
420
  return False
421
-
422
  test_model = get_chat_model("gemini")
423
  response = test_model.invoke("Say 'OK' in one word")
424
  print("OK: Google Gemini API connection successful")
425
  return True
426
-
427
  else:
428
  try:
429
  from langchain_ollama import ChatOllama
@@ -433,7 +434,7 @@ def check_api_connection():
433
  response = test_model.invoke("Hello")
434
  print("OK: Ollama connection successful")
435
  return True
436
-
437
  except Exception as e:
438
  print(f"ERROR: Connection failed: {e}")
439
  return False
 
14
 
15
  import os
16
  import threading
17
+ from typing import Literal
18
+
19
  from dotenv import load_dotenv
20
 
21
  # Load environment variables
 
65
 
66
 
67
  def get_chat_model(
68
+ provider: Literal["groq", "gemini", "ollama"] | None = None,
69
+ model: str | None = None,
70
  temperature: float = 0.0,
71
  json_mode: bool = False
72
  ):
 
84
  """
85
  # Use dynamic lookup to get current provider from environment
86
  provider = provider or get_default_llm_provider()
87
+
88
  if provider == "groq":
89
  from langchain_groq import ChatGroq
90
+
91
  api_key = get_groq_api_key()
92
  if not api_key:
93
  raise ValueError(
94
  "GROQ_API_KEY not found in environment.\n"
95
  "Get your FREE API key at: https://console.groq.com/keys"
96
  )
97
+
98
  # Use model from environment or default
99
  model = model or get_groq_model()
100
+
101
  return ChatGroq(
102
  model=model,
103
  temperature=temperature,
104
  api_key=api_key,
105
  model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {}
106
  )
107
+
108
  elif provider == "gemini":
109
  from langchain_google_genai import ChatGoogleGenerativeAI
110
+
111
  api_key = get_google_api_key()
112
  if not api_key:
113
  raise ValueError(
114
  "GOOGLE_API_KEY not found in environment.\n"
115
  "Get your FREE API key at: https://aistudio.google.com/app/apikey"
116
  )
117
+
118
  # Use model from environment or default
119
  model = model or get_gemini_model()
120
+
121
  return ChatGoogleGenerativeAI(
122
  model=model,
123
  temperature=temperature,
124
  google_api_key=api_key,
125
  convert_system_message_to_human=True
126
  )
127
+
128
  elif provider == "ollama":
129
  try:
130
  from langchain_ollama import ChatOllama
131
  except ImportError:
132
  from langchain_community.chat_models import ChatOllama
133
+
134
  model = model or "llama3.1:8b"
135
+
136
  return ChatOllama(
137
  model=model,
138
  temperature=temperature,
139
  format='json' if json_mode else None
140
  )
141
+
142
  else:
143
  raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
144
 
 
148
  return _get_env_with_fallback("EMBEDDING_PROVIDER", "EMBEDDING__PROVIDER", "huggingface")
149
 
150
 
151
+ def get_embedding_model(provider: Literal["jina", "google", "huggingface", "ollama"] | None = None):
152
  """
153
  Get embedding model for vector search.
154
 
 
163
  which has automatic fallback chain: Jina → Google → HuggingFace.
164
  """
165
  provider = provider or get_embedding_provider()
166
+
167
  if provider == "jina":
168
  # Try Jina AI embeddings first (high quality, 1024d)
169
  jina_key = _get_env_with_fallback("JINA_API_KEY", "EMBEDDING__JINA_API_KEY", "")
 
179
  else:
180
  print("WARN: JINA_API_KEY not found. Falling back to Google embeddings.")
181
  return get_embedding_model("google")
182
+
183
  elif provider == "google":
184
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
185
+
186
  api_key = get_google_api_key()
187
  if not api_key:
188
  print("WARN: GOOGLE_API_KEY not found. Falling back to HuggingFace embeddings.")
189
  return get_embedding_model("huggingface")
190
+
191
  try:
192
  return GoogleGenerativeAIEmbeddings(
193
  model="models/text-embedding-004",
 
197
  print(f"WARN: Google embeddings failed: {e}")
198
  print("INFO: Falling back to HuggingFace embeddings...")
199
  return get_embedding_model("huggingface")
200
+
201
  elif provider == "huggingface":
202
  try:
203
  from langchain_huggingface import HuggingFaceEmbeddings
204
  except ImportError:
205
  from langchain_community.embeddings import HuggingFaceEmbeddings
206
+
207
  return HuggingFaceEmbeddings(
208
  model_name="sentence-transformers/all-MiniLM-L6-v2"
209
  )
210
+
211
  elif provider == "ollama":
212
  try:
213
  from langchain_ollama import OllamaEmbeddings
214
  except ImportError:
215
  from langchain_community.embeddings import OllamaEmbeddings
216
+
217
  return OllamaEmbeddings(model="nomic-embed-text")
218
+
219
  else:
220
  raise ValueError(f"Unknown embedding provider: {provider}")
221
 
222
 
223
  class LLMConfig:
224
  """Central configuration for all LLM models"""
225
+
226
+ def __init__(self, provider: str | None = None, lazy: bool = True):
227
  """
228
  Initialize all model clients.
229
 
 
237
  self._initialized = False
238
  self._initialized_provider = None # Track which provider was initialized
239
  self._lock = threading.Lock()
240
+
241
  # Lazy-initialized model instances
242
  self._planner = None
243
  self._analyzer = None
 
246
  self._synthesizer_8b = None
247
  self._director = None
248
  self._embedding_model = None
249
+
250
  if not lazy:
251
  self._initialize_models()
252
+
253
  @property
254
  def provider(self) -> str:
255
  """Get current provider (dynamic lookup if not explicitly set)."""
256
  return self._explicit_provider or get_default_llm_provider()
257
+
258
  def _check_provider_change(self):
259
  """Check if provider changed and reinitialize if needed."""
260
  current = self.provider
 
267
  self._synthesizer_7b = None
268
  self._synthesizer_8b = None
269
  self._director = None
270
+
271
  def _initialize_models(self):
272
  """Initialize all model clients (called on first use if lazy)"""
273
  self._check_provider_change()
274
+
275
  if self._initialized:
276
  return
277
+
278
  with self._lock:
279
  # Double-checked locking
280
  if self._initialized:
281
  return
282
+
283
  print(f"Initializing LLM models with provider: {self.provider.upper()}")
284
+
285
  # Fast model for structured tasks (planning, analysis)
286
  self._planner = get_chat_model(
287
  provider=self.provider,
288
  temperature=0.0,
289
  json_mode=True
290
  )
291
+
292
  # Fast model for biomarker analysis and quick tasks
293
  self._analyzer = get_chat_model(
294
  provider=self.provider,
295
  temperature=0.0
296
  )
297
+
298
  # Medium model for RAG retrieval and explanation
299
  self._explainer = get_chat_model(
300
  provider=self.provider,
301
  temperature=0.2
302
  )
303
+
304
  # Configurable synthesizers
305
  self._synthesizer_7b = get_chat_model(
306
  provider=self.provider,
307
  temperature=0.2
308
  )
309
+
310
  self._synthesizer_8b = get_chat_model(
311
  provider=self.provider,
312
  temperature=0.2
313
  )
314
+
315
  # Director for Outer Loop
316
  self._director = get_chat_model(
317
  provider=self.provider,
318
  temperature=0.0,
319
  json_mode=True
320
  )
321
+
322
+ # Embedding model for RAG
323
  self._embedding_model = get_embedding_model()
324
+
325
  self._initialized = True
326
  self._initialized_provider = self.provider
327
+
328
  @property
329
  def planner(self):
330
  self._initialize_models()
331
  return self._planner
332
+
333
  @property
334
  def analyzer(self):
335
  self._initialize_models()
336
  return self._analyzer
337
+
338
  @property
339
  def explainer(self):
340
  self._initialize_models()
341
  return self._explainer
342
+
343
  @property
344
  def synthesizer_7b(self):
345
  self._initialize_models()
346
  return self._synthesizer_7b
347
+
348
  @property
349
  def synthesizer_8b(self):
350
  self._initialize_models()
351
  return self._synthesizer_8b
352
+
353
  @property
354
  def director(self):
355
  self._initialize_models()
356
  return self._director
357
+
358
  @property
359
  def embedding_model(self):
360
  self._initialize_models()
361
  return self._embedding_model
362
+
363
+ def get_synthesizer(self, model_name: str | None = None):
364
  """Get synthesizer model (for backward compatibility)"""
365
  if model_name:
366
  return get_chat_model(provider=self.provider, model=model_name, temperature=0.2)
367
  return self.synthesizer_8b
368
+
369
  def print_config(self):
370
  """Print current LLM configuration"""
371
  print("=" * 60)
372
  print("MediGuard AI RAG-Helper - LLM Configuration")
373
  print("=" * 60)
374
  print(f"Provider: {self.provider.upper()}")
375
+
376
  if self.provider == "groq":
377
+ print("Model: llama-3.3-70b-versatile (FREE)")
378
  elif self.provider == "gemini":
379
+ print("Model: gemini-2.0-flash (FREE)")
380
  else:
381
+ print("Model: llama3.1:8b (local)")
382
+
383
+ print("Embeddings: Google Gemini (FREE)")
384
  print("=" * 60)
385
 
386
 
 
388
  llm_config = LLMConfig()
389
 
390
 
391
+ def get_synthesizer(model_name: str | None = None):
392
  """Module-level convenience: get a synthesizer LLM instance."""
393
  return llm_config.get_synthesizer(model_name)
394
 
 
396
  def check_api_connection():
397
  """Verify API connection and keys are configured"""
398
  provider = DEFAULT_LLM_PROVIDER
399
+
400
  try:
401
  if provider == "groq":
402
  api_key = os.getenv("GROQ_API_KEY")
 
405
  print("\n Get your FREE API key at:")
406
  print(" https://console.groq.com/keys")
407
  return False
408
+
409
  # Test connection
410
  test_model = get_chat_model("groq")
411
  response = test_model.invoke("Say 'OK' in one word")
412
  print("OK: Groq API connection successful")
413
  return True
414
+
415
  elif provider == "gemini":
416
  api_key = os.getenv("GOOGLE_API_KEY")
417
  if not api_key:
 
419
  print("\n Get your FREE API key at:")
420
  print(" https://aistudio.google.com/app/apikey")
421
  return False
422
+
423
  test_model = get_chat_model("gemini")
424
  response = test_model.invoke("Say 'OK' in one word")
425
  print("OK: Google Gemini API connection successful")
426
  return True
427
+
428
  else:
429
  try:
430
  from langchain_ollama import ChatOllama
 
434
  response = test_model.invoke("Hello")
435
  print("OK: Ollama connection successful")
436
  return True
437
+
438
  except Exception as e:
439
  print(f"ERROR: Connection failed: {e}")
440
  return False
src/main.py CHANGED
@@ -13,7 +13,7 @@ import logging
13
  import os
14
  import time
15
  from contextlib import asynccontextmanager
16
- from datetime import datetime, timezone
17
 
18
  from fastapi import FastAPI, Request, status
19
  from fastapi.exceptions import RequestValidationError
@@ -49,7 +49,9 @@ async def lifespan(app: FastAPI):
49
  # --- OpenSearch ---
50
  try:
51
  from src.services.opensearch.client import make_opensearch_client
 
52
  app.state.opensearch_client = make_opensearch_client()
 
53
  logger.info("OpenSearch client ready")
54
  except Exception as exc:
55
  logger.warning("OpenSearch unavailable: %s", exc)
@@ -59,7 +61,7 @@ async def lifespan(app: FastAPI):
59
  try:
60
  from src.services.embeddings.service import make_embedding_service
61
  app.state.embedding_service = make_embedding_service()
62
- logger.info("Embedding service ready (provider=%s)", app.state.embedding_service._provider)
63
  except Exception as exc:
64
  logger.warning("Embedding service unavailable: %s", exc)
65
  app.state.embedding_service = None
@@ -93,11 +95,11 @@ async def lifespan(app: FastAPI):
93
 
94
  # --- Agentic RAG service ---
95
  try:
 
96
  from src.services.agents.agentic_rag import AgenticRAGService
97
  from src.services.agents.context import AgenticContext
98
-
99
- if app.state.ollama_client and app.state.opensearch_client and app.state.embedding_service:
100
- llm = app.state.ollama_client.get_langchain_model()
101
  ctx = AgenticContext(
102
  llm=llm,
103
  embedding_service=app.state.embedding_service,
@@ -109,17 +111,16 @@ async def lifespan(app: FastAPI):
109
  logger.info("Agentic RAG service ready")
110
  else:
111
  app.state.rag_service = None
112
- logger.warning("Agentic RAG service skipped — missing backing services")
113
  except Exception as exc:
114
  logger.warning("Agentic RAG service failed: %s", exc)
115
  app.state.rag_service = None
116
 
117
  # --- Legacy RagBot service (backward-compatible /analyze) ---
118
  try:
119
- from api.app.services.ragbot import get_ragbot_service
120
- ragbot = get_ragbot_service()
121
- ragbot.initialize()
122
- app.state.ragbot_service = ragbot
123
  logger.info("RagBot service ready (ClinicalInsightGuild)")
124
  except Exception as exc:
125
  logger.warning("RagBot service unavailable: %s", exc)
@@ -127,17 +128,13 @@ async def lifespan(app: FastAPI):
127
 
128
  # --- Extraction service (for natural language input) ---
129
  try:
 
130
  from src.services.extraction.service import make_extraction_service
131
- llm = None
132
- if app.state.ollama_client:
133
- llm = app.state.ollama_client.get_langchain_model()
134
- elif hasattr(app.state, 'rag_service') and app.state.rag_service:
135
- # Use the same LLM as agentic RAG
136
- llm = getattr(app.state.rag_service, '_context', {})
137
- if hasattr(llm, 'llm'):
138
- llm = llm.llm
139
- else:
140
- llm = None
141
  # If no LLM available, extraction will use regex fallback
142
  app.state.extraction_service = make_extraction_service(llm=llm)
143
  logger.info("Extraction service ready")
@@ -196,7 +193,7 @@ def create_app() -> FastAPI:
196
  "error_code": "VALIDATION_ERROR",
197
  "message": "Request validation failed",
198
  "details": exc.errors(),
199
- "timestamp": datetime.now(timezone.utc).isoformat(),
200
  },
201
  )
202
 
@@ -209,12 +206,12 @@ def create_app() -> FastAPI:
209
  "status": "error",
210
  "error_code": "INTERNAL_SERVER_ERROR",
211
  "message": "An unexpected error occurred. Please try again later.",
212
- "timestamp": datetime.now(timezone.utc).isoformat(),
213
  },
214
  )
215
 
216
  # --- Routers ---
217
- from src.routers import health, analyze, ask, search
218
 
219
  app.include_router(health.router)
220
  app.include_router(analyze.router)
 
13
  import os
14
  import time
15
  from contextlib import asynccontextmanager
16
+ from datetime import UTC, datetime
17
 
18
  from fastapi import FastAPI, Request, status
19
  from fastapi.exceptions import RequestValidationError
 
49
  # --- OpenSearch ---
50
  try:
51
  from src.services.opensearch.client import make_opensearch_client
52
+ from src.services.opensearch.index_config import MEDICAL_CHUNKS_MAPPING
53
  app.state.opensearch_client = make_opensearch_client()
54
+ app.state.opensearch_client.ensure_index(MEDICAL_CHUNKS_MAPPING)
55
  logger.info("OpenSearch client ready")
56
  except Exception as exc:
57
  logger.warning("OpenSearch unavailable: %s", exc)
 
61
  try:
62
  from src.services.embeddings.service import make_embedding_service
63
  app.state.embedding_service = make_embedding_service()
64
+ logger.info("Embedding service ready (provider=%s)", app.state.embedding_service.provider_name)
65
  except Exception as exc:
66
  logger.warning("Embedding service unavailable: %s", exc)
67
  app.state.embedding_service = None
 
95
 
96
  # --- Agentic RAG service ---
97
  try:
98
+ from src.llm_config import get_llm
99
  from src.services.agents.agentic_rag import AgenticRAGService
100
  from src.services.agents.context import AgenticContext
101
+ if app.state.opensearch_client and app.state.embedding_service:
102
+ llm = get_llm()
 
103
  ctx = AgenticContext(
104
  llm=llm,
105
  embedding_service=app.state.embedding_service,
 
111
  logger.info("Agentic RAG service ready")
112
  else:
113
  app.state.rag_service = None
114
+ logger.warning("Agentic RAG service skipped — missing backing services (OpenSearch or Embedding)")
115
  except Exception as exc:
116
  logger.warning("Agentic RAG service failed: %s", exc)
117
  app.state.rag_service = None
118
 
119
  # --- Legacy RagBot service (backward-compatible /analyze) ---
120
  try:
121
+ from src.workflow import create_guild
122
+ guild = create_guild()
123
+ app.state.ragbot_service = guild
 
124
  logger.info("RagBot service ready (ClinicalInsightGuild)")
125
  except Exception as exc:
126
  logger.warning("RagBot service unavailable: %s", exc)
 
128
 
129
  # --- Extraction service (for natural language input) ---
130
  try:
131
+ from src.llm_config import get_llm
132
  from src.services.extraction.service import make_extraction_service
133
+ try:
134
+ llm = get_llm()
135
+ except Exception as e:
136
+ logger.warning("Failed to get LLM for extraction, will use fallback: %s", e)
137
+ llm = None
 
 
 
 
 
138
  # If no LLM available, extraction will use regex fallback
139
  app.state.extraction_service = make_extraction_service(llm=llm)
140
  logger.info("Extraction service ready")
 
193
  "error_code": "VALIDATION_ERROR",
194
  "message": "Request validation failed",
195
  "details": exc.errors(),
196
+ "timestamp": datetime.now(UTC).isoformat(),
197
  },
198
  )
199
 
 
206
  "status": "error",
207
  "error_code": "INTERNAL_SERVER_ERROR",
208
  "message": "An unexpected error occurred. Please try again later.",
209
+ "timestamp": datetime.now(UTC).isoformat(),
210
  },
211
  )
212
 
213
  # --- Routers ---
214
+ from src.routers import analyze, ask, health, search
215
 
216
  app.include_router(health.router)
217
  app.include_router(analyze.router)
src/middlewares.py CHANGED
@@ -12,8 +12,9 @@ import json
12
  import logging
13
  import time
14
  import uuid
15
- from datetime import datetime, timezone
16
- from typing import Any, Callable
 
17
 
18
  from fastapi import Request, Response
19
  from starlette.middleware.base import BaseHTTPMiddleware
@@ -74,35 +75,35 @@ class HIPAAAuditMiddleware(BaseHTTPMiddleware):
74
 
75
  Audit logs are structured JSON for easy SIEM integration.
76
  """
77
-
78
  async def dispatch(self, request: Request, call_next: Callable) -> Response:
79
  # Generate request ID
80
  request_id = f"req_{uuid.uuid4().hex[:12]}"
81
  request.state.request_id = request_id
82
-
83
  # Start timing
84
  start_time = time.time()
85
-
86
  # Extract metadata safely
87
  path = request.url.path
88
  method = request.method
89
  client_ip = request.client.host if request.client else "unknown"
90
  user_agent = request.headers.get("user-agent", "unknown")[:100]
91
-
92
  # Check if this endpoint needs audit logging
93
  needs_audit = any(path.startswith(ep) for ep in AUDITABLE_ENDPOINTS)
94
-
95
  # Pre-request audit entry
96
  audit_entry: dict[str, Any] = {
97
  "event": "request_start",
98
- "timestamp": datetime.now(timezone.utc).isoformat(),
99
  "request_id": request_id,
100
  "method": method,
101
  "path": path,
102
  "client_ip_hash": _hash_sensitive(client_ip),
103
  "user_agent_hash": _hash_sensitive(user_agent),
104
  }
105
-
106
  # Try to read request body for POST requests (without logging PHI)
107
  if needs_audit and method == "POST":
108
  try:
@@ -116,35 +117,35 @@ class HIPAAAuditMiddleware(BaseHTTPMiddleware):
116
  # Log presence of biomarkers without values
117
  if "biomarkers" in body_dict:
118
  audit_entry["biomarker_count"] = len(body_dict["biomarkers"]) if isinstance(body_dict["biomarkers"], dict) else 1
119
- except Exception:
120
- pass
121
-
122
  if needs_audit:
123
  logger.info("AUDIT_REQUEST: %s", json.dumps(audit_entry))
124
-
125
  # Process request
126
  response: Response = await call_next(request)
127
-
128
  # Post-request audit
129
  elapsed_ms = (time.time() - start_time) * 1000
130
-
131
  completion_entry = {
132
  "event": "request_complete",
133
- "timestamp": datetime.now(timezone.utc).isoformat(),
134
  "request_id": request_id,
135
  "method": method,
136
  "path": path,
137
  "status_code": response.status_code,
138
  "elapsed_ms": round(elapsed_ms, 2),
139
  }
140
-
141
  if needs_audit:
142
  logger.info("AUDIT_COMPLETE: %s", json.dumps(completion_entry))
143
-
144
  # Add request ID to response headers
145
  response.headers["X-Request-ID"] = request_id
146
  response.headers["X-Response-Time"] = f"{elapsed_ms:.2f}ms"
147
-
148
  return response
149
 
150
 
@@ -152,10 +153,10 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
152
  """
153
  Add security headers for HIPAA compliance.
154
  """
155
-
156
  async def dispatch(self, request: Request, call_next: Callable) -> Response:
157
  response: Response = await call_next(request)
158
-
159
  # Security headers
160
  response.headers["X-Content-Type-Options"] = "nosniff"
161
  response.headers["X-Frame-Options"] = "DENY"
@@ -163,9 +164,9 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
163
  response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
164
  response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
165
  response.headers["Pragma"] = "no-cache"
166
-
167
  # Medical data should never be cached
168
  if any(ep in request.url.path for ep in AUDITABLE_ENDPOINTS):
169
  response.headers["Cache-Control"] = "no-store, private"
170
-
171
  return response
 
12
  import logging
13
  import time
14
  import uuid
15
+ from collections.abc import Callable
16
+ from datetime import UTC, datetime
17
+ from typing import Any
18
 
19
  from fastapi import Request, Response
20
  from starlette.middleware.base import BaseHTTPMiddleware
 
75
 
76
  Audit logs are structured JSON for easy SIEM integration.
77
  """
78
+
79
  async def dispatch(self, request: Request, call_next: Callable) -> Response:
80
  # Generate request ID
81
  request_id = f"req_{uuid.uuid4().hex[:12]}"
82
  request.state.request_id = request_id
83
+
84
  # Start timing
85
  start_time = time.time()
86
+
87
  # Extract metadata safely
88
  path = request.url.path
89
  method = request.method
90
  client_ip = request.client.host if request.client else "unknown"
91
  user_agent = request.headers.get("user-agent", "unknown")[:100]
92
+
93
  # Check if this endpoint needs audit logging
94
  needs_audit = any(path.startswith(ep) for ep in AUDITABLE_ENDPOINTS)
95
+
96
  # Pre-request audit entry
97
  audit_entry: dict[str, Any] = {
98
  "event": "request_start",
99
+ "timestamp": datetime.now(UTC).isoformat(),
100
  "request_id": request_id,
101
  "method": method,
102
  "path": path,
103
  "client_ip_hash": _hash_sensitive(client_ip),
104
  "user_agent_hash": _hash_sensitive(user_agent),
105
  }
106
+
107
  # Try to read request body for POST requests (without logging PHI)
108
  if needs_audit and method == "POST":
109
  try:
 
117
  # Log presence of biomarkers without values
118
  if "biomarkers" in body_dict:
119
  audit_entry["biomarker_count"] = len(body_dict["biomarkers"]) if isinstance(body_dict["biomarkers"], dict) else 1
120
+ except Exception as exc:
121
+ logger.debug("Failed to audit POST body: %s", exc)
122
+
123
  if needs_audit:
124
  logger.info("AUDIT_REQUEST: %s", json.dumps(audit_entry))
125
+
126
  # Process request
127
  response: Response = await call_next(request)
128
+
129
  # Post-request audit
130
  elapsed_ms = (time.time() - start_time) * 1000
131
+
132
  completion_entry = {
133
  "event": "request_complete",
134
+ "timestamp": datetime.now(UTC).isoformat(),
135
  "request_id": request_id,
136
  "method": method,
137
  "path": path,
138
  "status_code": response.status_code,
139
  "elapsed_ms": round(elapsed_ms, 2),
140
  }
141
+
142
  if needs_audit:
143
  logger.info("AUDIT_COMPLETE: %s", json.dumps(completion_entry))
144
+
145
  # Add request ID to response headers
146
  response.headers["X-Request-ID"] = request_id
147
  response.headers["X-Response-Time"] = f"{elapsed_ms:.2f}ms"
148
+
149
  return response
150
 
151
 
 
153
  """
154
  Add security headers for HIPAA compliance.
155
  """
156
+
157
  async def dispatch(self, request: Request, call_next: Callable) -> Response:
158
  response: Response = await call_next(request)
159
+
160
  # Security headers
161
  response.headers["X-Content-Type-Options"] = "nosniff"
162
  response.headers["X-Frame-Options"] = "DENY"
 
164
  response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
165
  response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
166
  response.headers["Pragma"] = "no-cache"
167
+
168
  # Medical data should never be cached
169
  if any(ep in request.url.path for ep in AUDITABLE_ENDPOINTS):
170
  response.headers["Cache-Control"] = "no-store, private"
171
+
172
  return response
src/pdf_processor.py CHANGED
@@ -6,13 +6,12 @@ PDF document processing and vector store creation
6
  import os
7
  import warnings
8
  from pathlib import Path
9
- from typing import List, Optional
10
- from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
11
- from langchain_text_splitters import RecursiveCharacterTextSplitter
12
  from langchain_community.vectorstores import FAISS
13
  from langchain_core.documents import Document
14
- from dotenv import load_dotenv
15
- import time
16
 
17
  # Suppress noisy warnings
18
  warnings.filterwarnings("ignore", message=".*class.*HuggingFaceEmbeddings.*was deprecated.*")
@@ -22,12 +21,12 @@ os.environ.setdefault("HF_HUB_DISABLE_IMPLICIT_TOKEN", "1")
22
  load_dotenv()
23
 
24
  # Re-export for backward compatibility
25
- from src.llm_config import get_embedding_model # noqa: F401
26
 
27
 
28
  class PDFProcessor:
29
  """Handles medical PDF ingestion and vector store creation"""
30
-
31
  def __init__(
32
  self,
33
  pdf_directory: str = "data/medical_pdfs",
@@ -48,11 +47,11 @@ class PDFProcessor:
48
  self.vector_store_path = Path(vector_store_path)
49
  self.chunk_size = chunk_size
50
  self.chunk_overlap = chunk_overlap
51
-
52
  # Create directories if they don't exist
53
  self.pdf_directory.mkdir(parents=True, exist_ok=True)
54
  self.vector_store_path.mkdir(parents=True, exist_ok=True)
55
-
56
  # Text splitter with medical context awareness
57
  self.text_splitter = RecursiveCharacterTextSplitter(
58
  chunk_size=chunk_size,
@@ -60,8 +59,8 @@ class PDFProcessor:
60
  separators=["\n\n", "\n", ". ", " ", ""],
61
  length_function=len
62
  )
63
-
64
- def load_pdfs(self) -> List[Document]:
65
  """
66
  Load all PDF documents from the configured directory.
67
 
@@ -69,40 +68,40 @@ class PDFProcessor:
69
  List of Document objects with content and metadata
70
  """
71
  print(f"Loading PDFs from: {self.pdf_directory}")
72
-
73
  pdf_files = list(self.pdf_directory.glob("*.pdf"))
74
-
75
  if not pdf_files:
76
  print(f"WARN: No PDF files found in {self.pdf_directory}")
77
  print("INFO: Please place medical PDFs in this directory")
78
  return []
79
-
80
  print(f"Found {len(pdf_files)} PDF file(s):")
81
  for pdf in pdf_files:
82
  print(f" - {pdf.name}")
83
-
84
  documents = []
85
-
86
  for pdf_path in pdf_files:
87
  try:
88
  loader = PyPDFLoader(str(pdf_path))
89
  docs = loader.load()
90
-
91
  # Add source filename to metadata
92
  for doc in docs:
93
  doc.metadata['source_file'] = pdf_path.name
94
  doc.metadata['source_path'] = str(pdf_path)
95
-
96
  documents.extend(docs)
97
  print(f" OK: Loaded {len(docs)} pages from {pdf_path.name}")
98
-
99
  except Exception as e:
100
  print(f" ERROR: Error loading {pdf_path.name}: {e}")
101
-
102
  print(f"\nTotal: {len(documents)} pages loaded from {len(pdf_files)} PDF(s)")
103
  return documents
104
-
105
- def chunk_documents(self, documents: List[Document]) -> List[Document]:
106
  """
107
  Split documents into chunks for RAG retrieval.
108
 
@@ -113,25 +112,25 @@ class PDFProcessor:
113
  List of chunked documents with preserved metadata
114
  """
115
  print(f"\nChunking documents (size={self.chunk_size}, overlap={self.chunk_overlap})...")
116
-
117
  chunks = self.text_splitter.split_documents(documents)
118
-
119
  if not chunks:
120
  print("WARN: No chunks generated from documents")
121
  return chunks
122
-
123
  # Add chunk index to metadata
124
  for i, chunk in enumerate(chunks):
125
  chunk.metadata['chunk_id'] = i
126
-
127
  print(f"OK: Created {len(chunks)} chunks from {len(documents)} pages")
128
  print(f" Average chunk size: {sum(len(c.page_content) for c in chunks) // len(chunks)} characters")
129
-
130
  return chunks
131
-
132
  def create_vector_store(
133
  self,
134
- chunks: List[Document],
135
  embedding_model,
136
  store_name: str = "medical_knowledge"
137
  ) -> FAISS:
@@ -149,26 +148,26 @@ class PDFProcessor:
149
  print(f"\nCreating vector store: {store_name}")
150
  print(f"Generating embeddings for {len(chunks)} chunks...")
151
  print("(This may take a few minutes...)")
152
-
153
  # Create FAISS vector store
154
  vector_store = FAISS.from_documents(
155
  documents=chunks,
156
  embedding=embedding_model
157
  )
158
-
159
  # Save to disk
160
  save_path = self.vector_store_path / f"{store_name}.faiss"
161
  vector_store.save_local(str(self.vector_store_path), index_name=store_name)
162
-
163
  print(f"OK: Vector store created and saved to: {save_path}")
164
-
165
  return vector_store
166
-
167
  def load_vector_store(
168
  self,
169
  embedding_model,
170
  store_name: str = "medical_knowledge"
171
- ) -> Optional[FAISS]:
172
  """
173
  Load existing vector store from disk.
174
 
@@ -180,11 +179,11 @@ class PDFProcessor:
180
  FAISS vector store or None if not found
181
  """
182
  store_path = self.vector_store_path / f"{store_name}.faiss"
183
-
184
  if not store_path.exists():
185
  print(f"WARN: Vector store not found: {store_path}")
186
  return None
187
-
188
  try:
189
  # SECURITY NOTE: allow_dangerous_deserialization=True uses pickle.
190
  # Only load vector stores from trusted, locally-built sources.
@@ -197,11 +196,11 @@ class PDFProcessor:
197
  )
198
  print(f"OK: Loaded vector store from: {store_path}")
199
  return vector_store
200
-
201
  except Exception as e:
202
  print(f"ERROR: Error loading vector store: {e}")
203
  return None
204
-
205
  def create_retrievers(
206
  self,
207
  embedding_model,
@@ -224,19 +223,19 @@ class PDFProcessor:
224
  vector_store = self.load_vector_store(embedding_model, store_name)
225
  else:
226
  vector_store = None
227
-
228
  # If not found, create new one
229
  if vector_store is None:
230
  print("\nBuilding new vector store from PDFs...")
231
  documents = self.load_pdfs()
232
-
233
  if not documents:
234
  print("WARN: No documents to process. Please add PDF files.")
235
  return {}
236
-
237
  chunks = self.chunk_documents(documents)
238
  vector_store = self.create_vector_store(chunks, embedding_model, store_name)
239
-
240
  # Create specialized retrievers
241
  retrievers = {
242
  "disease_explainer": vector_store.as_retriever(
@@ -252,7 +251,7 @@ class PDFProcessor:
252
  search_kwargs={"k": 5}
253
  )
254
  }
255
-
256
  print(f"\nOK: Created {len(retrievers)} specialized retrievers")
257
  return retrievers
258
 
@@ -272,28 +271,28 @@ def setup_knowledge_base(embedding_model=None, force_rebuild: bool = False, use_
272
  print("=" * 60)
273
  print("Setting up Medical Knowledge Base")
274
  print("=" * 60)
275
-
276
  # Use configured embedding provider from environment
277
  if use_configured_embeddings and embedding_model is None:
278
  embedding_model = get_embedding_model()
279
  print(" > Embeddings model loaded")
280
  elif embedding_model is None:
281
  raise ValueError("Must provide embedding_model or set use_configured_embeddings=True")
282
-
283
  processor = PDFProcessor()
284
  retrievers = processor.create_retrievers(
285
  embedding_model,
286
  store_name="medical_knowledge",
287
  force_rebuild=force_rebuild
288
  )
289
-
290
  if retrievers:
291
  print("\nOK: Knowledge base setup complete!")
292
  else:
293
  print("\nWARN: Knowledge base setup incomplete. Add PDFs and try again.")
294
-
295
  print("=" * 60)
296
-
297
  return retrievers
298
 
299
 
@@ -320,22 +319,22 @@ if __name__ == "__main__":
320
  # Test PDF processing
321
  import sys
322
  from pathlib import Path
323
-
324
  # Add parent directory to path for imports
325
  sys.path.insert(0, str(Path(__file__).parent.parent))
326
-
327
  print("\n" + "="*70)
328
  print("MediGuard AI - PDF Knowledge Base Builder")
329
  print("="*70)
330
  print("\nUsing configured embedding provider from .env")
331
  print(" EMBEDDING_PROVIDER options: google (default), huggingface, ollama")
332
  print("="*70)
333
-
334
  retrievers = setup_knowledge_base(
335
  use_configured_embeddings=True, # Use configured provider
336
  force_rebuild=False
337
  )
338
-
339
  if retrievers:
340
  print("\nOK: PDF processing test successful!")
341
  print(f"Available retrievers: {list(retrievers.keys())}")
 
6
  import os
7
  import warnings
8
  from pathlib import Path
9
+
10
+ from dotenv import load_dotenv
11
+ from langchain_community.document_loaders import PyPDFLoader
12
  from langchain_community.vectorstores import FAISS
13
  from langchain_core.documents import Document
14
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
 
15
 
16
  # Suppress noisy warnings
17
  warnings.filterwarnings("ignore", message=".*class.*HuggingFaceEmbeddings.*was deprecated.*")
 
21
  load_dotenv()
22
 
23
  # Re-export for backward compatibility
24
+ from src.llm_config import get_embedding_model
25
 
26
 
27
  class PDFProcessor:
28
  """Handles medical PDF ingestion and vector store creation"""
29
+
30
  def __init__(
31
  self,
32
  pdf_directory: str = "data/medical_pdfs",
 
47
  self.vector_store_path = Path(vector_store_path)
48
  self.chunk_size = chunk_size
49
  self.chunk_overlap = chunk_overlap
50
+
51
  # Create directories if they don't exist
52
  self.pdf_directory.mkdir(parents=True, exist_ok=True)
53
  self.vector_store_path.mkdir(parents=True, exist_ok=True)
54
+
55
  # Text splitter with medical context awareness
56
  self.text_splitter = RecursiveCharacterTextSplitter(
57
  chunk_size=chunk_size,
 
59
  separators=["\n\n", "\n", ". ", " ", ""],
60
  length_function=len
61
  )
62
+
63
+ def load_pdfs(self) -> list[Document]:
64
  """
65
  Load all PDF documents from the configured directory.
66
 
 
68
  List of Document objects with content and metadata
69
  """
70
  print(f"Loading PDFs from: {self.pdf_directory}")
71
+
72
  pdf_files = list(self.pdf_directory.glob("*.pdf"))
73
+
74
  if not pdf_files:
75
  print(f"WARN: No PDF files found in {self.pdf_directory}")
76
  print("INFO: Please place medical PDFs in this directory")
77
  return []
78
+
79
  print(f"Found {len(pdf_files)} PDF file(s):")
80
  for pdf in pdf_files:
81
  print(f" - {pdf.name}")
82
+
83
  documents = []
84
+
85
  for pdf_path in pdf_files:
86
  try:
87
  loader = PyPDFLoader(str(pdf_path))
88
  docs = loader.load()
89
+
90
  # Add source filename to metadata
91
  for doc in docs:
92
  doc.metadata['source_file'] = pdf_path.name
93
  doc.metadata['source_path'] = str(pdf_path)
94
+
95
  documents.extend(docs)
96
  print(f" OK: Loaded {len(docs)} pages from {pdf_path.name}")
97
+
98
  except Exception as e:
99
  print(f" ERROR: Error loading {pdf_path.name}: {e}")
100
+
101
  print(f"\nTotal: {len(documents)} pages loaded from {len(pdf_files)} PDF(s)")
102
  return documents
103
+
104
+ def chunk_documents(self, documents: list[Document]) -> list[Document]:
105
  """
106
  Split documents into chunks for RAG retrieval.
107
 
 
112
  List of chunked documents with preserved metadata
113
  """
114
  print(f"\nChunking documents (size={self.chunk_size}, overlap={self.chunk_overlap})...")
115
+
116
  chunks = self.text_splitter.split_documents(documents)
117
+
118
  if not chunks:
119
  print("WARN: No chunks generated from documents")
120
  return chunks
121
+
122
  # Add chunk index to metadata
123
  for i, chunk in enumerate(chunks):
124
  chunk.metadata['chunk_id'] = i
125
+
126
  print(f"OK: Created {len(chunks)} chunks from {len(documents)} pages")
127
  print(f" Average chunk size: {sum(len(c.page_content) for c in chunks) // len(chunks)} characters")
128
+
129
  return chunks
130
+
131
  def create_vector_store(
132
  self,
133
+ chunks: list[Document],
134
  embedding_model,
135
  store_name: str = "medical_knowledge"
136
  ) -> FAISS:
 
148
  print(f"\nCreating vector store: {store_name}")
149
  print(f"Generating embeddings for {len(chunks)} chunks...")
150
  print("(This may take a few minutes...)")
151
+
152
  # Create FAISS vector store
153
  vector_store = FAISS.from_documents(
154
  documents=chunks,
155
  embedding=embedding_model
156
  )
157
+
158
  # Save to disk
159
  save_path = self.vector_store_path / f"{store_name}.faiss"
160
  vector_store.save_local(str(self.vector_store_path), index_name=store_name)
161
+
162
  print(f"OK: Vector store created and saved to: {save_path}")
163
+
164
  return vector_store
165
+
166
  def load_vector_store(
167
  self,
168
  embedding_model,
169
  store_name: str = "medical_knowledge"
170
+ ) -> FAISS | None:
171
  """
172
  Load existing vector store from disk.
173
 
 
179
  FAISS vector store or None if not found
180
  """
181
  store_path = self.vector_store_path / f"{store_name}.faiss"
182
+
183
  if not store_path.exists():
184
  print(f"WARN: Vector store not found: {store_path}")
185
  return None
186
+
187
  try:
188
  # SECURITY NOTE: allow_dangerous_deserialization=True uses pickle.
189
  # Only load vector stores from trusted, locally-built sources.
 
196
  )
197
  print(f"OK: Loaded vector store from: {store_path}")
198
  return vector_store
199
+
200
  except Exception as e:
201
  print(f"ERROR: Error loading vector store: {e}")
202
  return None
203
+
204
  def create_retrievers(
205
  self,
206
  embedding_model,
 
223
  vector_store = self.load_vector_store(embedding_model, store_name)
224
  else:
225
  vector_store = None
226
+
227
  # If not found, create new one
228
  if vector_store is None:
229
  print("\nBuilding new vector store from PDFs...")
230
  documents = self.load_pdfs()
231
+
232
  if not documents:
233
  print("WARN: No documents to process. Please add PDF files.")
234
  return {}
235
+
236
  chunks = self.chunk_documents(documents)
237
  vector_store = self.create_vector_store(chunks, embedding_model, store_name)
238
+
239
  # Create specialized retrievers
240
  retrievers = {
241
  "disease_explainer": vector_store.as_retriever(
 
251
  search_kwargs={"k": 5}
252
  )
253
  }
254
+
255
  print(f"\nOK: Created {len(retrievers)} specialized retrievers")
256
  return retrievers
257
 
 
271
  print("=" * 60)
272
  print("Setting up Medical Knowledge Base")
273
  print("=" * 60)
274
+
275
  # Use configured embedding provider from environment
276
  if use_configured_embeddings and embedding_model is None:
277
  embedding_model = get_embedding_model()
278
  print(" > Embeddings model loaded")
279
  elif embedding_model is None:
280
  raise ValueError("Must provide embedding_model or set use_configured_embeddings=True")
281
+
282
  processor = PDFProcessor()
283
  retrievers = processor.create_retrievers(
284
  embedding_model,
285
  store_name="medical_knowledge",
286
  force_rebuild=force_rebuild
287
  )
288
+
289
  if retrievers:
290
  print("\nOK: Knowledge base setup complete!")
291
  else:
292
  print("\nWARN: Knowledge base setup incomplete. Add PDFs and try again.")
293
+
294
  print("=" * 60)
295
+
296
  return retrievers
297
 
298
 
 
319
  # Test PDF processing
320
  import sys
321
  from pathlib import Path
322
+
323
  # Add parent directory to path for imports
324
  sys.path.insert(0, str(Path(__file__).parent.parent))
325
+
326
  print("\n" + "="*70)
327
  print("MediGuard AI - PDF Knowledge Base Builder")
328
  print("="*70)
329
  print("\nUsing configured embedding provider from .env")
330
  print(" EMBEDDING_PROVIDER options: google (default), huggingface, ollama")
331
  print("="*70)
332
+
333
  retrievers = setup_knowledge_base(
334
  use_configured_embeddings=True, # Use configured provider
335
  force_rebuild=False
336
  )
337
+
338
  if retrievers:
339
  print("\nOK: PDF processing test successful!")
340
  print(f"Available retrievers: {list(retrievers.keys())}")
src/repositories/analysis.py CHANGED
@@ -4,8 +4,6 @@ MediGuard AI — Analysis repository (data-access layer).
4
 
5
  from __future__ import annotations
6
 
7
- from typing import List, Optional
8
-
9
  from sqlalchemy.orm import Session
10
 
11
  from src.models.analysis import PatientAnalysis
@@ -22,14 +20,14 @@ class AnalysisRepository:
22
  self.db.flush()
23
  return analysis
24
 
25
- def get_by_request_id(self, request_id: str) -> Optional[PatientAnalysis]:
26
  return (
27
  self.db.query(PatientAnalysis)
28
  .filter(PatientAnalysis.request_id == request_id)
29
  .first()
30
  )
31
 
32
- def list_recent(self, limit: int = 20) -> List[PatientAnalysis]:
33
  return (
34
  self.db.query(PatientAnalysis)
35
  .order_by(PatientAnalysis.created_at.desc())
 
4
 
5
  from __future__ import annotations
6
 
 
 
7
  from sqlalchemy.orm import Session
8
 
9
  from src.models.analysis import PatientAnalysis
 
20
  self.db.flush()
21
  return analysis
22
 
23
+ def get_by_request_id(self, request_id: str) -> PatientAnalysis | None:
24
  return (
25
  self.db.query(PatientAnalysis)
26
  .filter(PatientAnalysis.request_id == request_id)
27
  .first()
28
  )
29
 
30
+ def list_recent(self, limit: int = 20) -> list[PatientAnalysis]:
31
  return (
32
  self.db.query(PatientAnalysis)
33
  .order_by(PatientAnalysis.created_at.desc())
src/repositories/document.py CHANGED
@@ -4,8 +4,6 @@ MediGuard AI — Document repository.
4
 
5
  from __future__ import annotations
6
 
7
- from typing import List, Optional
8
-
9
  from sqlalchemy.orm import Session
10
 
11
  from src.models.analysis import MedicalDocument
@@ -33,10 +31,10 @@ class DocumentRepository:
33
  self.db.flush()
34
  return doc
35
 
36
- def get_by_id(self, doc_id: str) -> Optional[MedicalDocument]:
37
  return self.db.query(MedicalDocument).filter(MedicalDocument.id == doc_id).first()
38
 
39
- def list_all(self, limit: int = 100) -> List[MedicalDocument]:
40
  return (
41
  self.db.query(MedicalDocument)
42
  .order_by(MedicalDocument.created_at.desc())
 
4
 
5
  from __future__ import annotations
6
 
 
 
7
  from sqlalchemy.orm import Session
8
 
9
  from src.models.analysis import MedicalDocument
 
31
  self.db.flush()
32
  return doc
33
 
34
+ def get_by_id(self, doc_id: str) -> MedicalDocument | None:
35
  return self.db.query(MedicalDocument).filter(MedicalDocument.id == doc_id).first()
36
 
37
+ def list_all(self, limit: int = 100) -> list[MedicalDocument]:
38
  return (
39
  self.db.query(MedicalDocument)
40
  .order_by(MedicalDocument.created_at.desc())
src/routers/analyze.py CHANGED
@@ -12,8 +12,8 @@ import logging
12
  import time
13
  import uuid
14
  from concurrent.futures import ThreadPoolExecutor
15
- from datetime import datetime, timezone
16
- from typing import Any, Dict
17
 
18
  from fastapi import APIRouter, HTTPException, Request
19
 
@@ -30,7 +30,7 @@ router = APIRouter(prefix="/analyze", tags=["analysis"])
30
  _executor = ThreadPoolExecutor(max_workers=4)
31
 
32
 
33
- def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
34
  """Rule-based disease scoring (NOT ML prediction)."""
35
  scores = {
36
  "Diabetes": 0.0,
@@ -39,7 +39,7 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
39
  "Thrombocytopenia": 0.0,
40
  "Thalassemia": 0.0
41
  }
42
-
43
  # Diabetes indicators
44
  glucose = biomarkers.get("Glucose")
45
  hba1c = biomarkers.get("HbA1c")
@@ -49,7 +49,7 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
49
  scores["Diabetes"] += 0.2
50
  if hba1c is not None and hba1c >= 6.5:
51
  scores["Diabetes"] += 0.5
52
-
53
  # Anemia indicators
54
  hemoglobin = biomarkers.get("Hemoglobin")
55
  mcv = biomarkers.get("Mean Corpuscular Volume", biomarkers.get("MCV"))
@@ -59,7 +59,7 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
59
  scores["Anemia"] += 0.2
60
  if mcv is not None and mcv < 80:
61
  scores["Anemia"] += 0.2
62
-
63
  # Heart disease indicators
64
  cholesterol = biomarkers.get("Cholesterol")
65
  troponin = biomarkers.get("Troponin")
@@ -70,32 +70,32 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
70
  scores["Heart Disease"] += 0.6
71
  if ldl is not None and ldl > 190:
72
  scores["Heart Disease"] += 0.2
73
-
74
  # Thrombocytopenia indicators
75
  platelets = biomarkers.get("Platelets")
76
  if platelets is not None and platelets < 150000:
77
  scores["Thrombocytopenia"] += 0.6
78
  if platelets is not None and platelets < 50000:
79
  scores["Thrombocytopenia"] += 0.3
80
-
81
  # Thalassemia indicators
82
  if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
83
  scores["Thalassemia"] += 0.4
84
-
85
  # Find top prediction
86
  top_disease = max(scores, key=scores.get)
87
  confidence = min(scores[top_disease], 1.0)
88
-
89
  if confidence == 0.0:
90
  top_disease = "Undetermined"
91
-
92
  # Normalize probabilities
93
  total = sum(scores.values())
94
  if total > 0:
95
  probabilities = {k: v / total for k, v in scores.items()}
96
  else:
97
  probabilities = {k: 1.0 / len(scores) for k in scores}
98
-
99
  return {
100
  "disease": top_disease,
101
  "confidence": confidence,
@@ -105,16 +105,16 @@ def _score_disease_heuristic(biomarkers: Dict[str, float]) -> Dict[str, Any]:
105
 
106
  async def _run_guild_analysis(
107
  request: Request,
108
- biomarkers: Dict[str, float],
109
- patient_ctx: Dict[str, Any],
110
- extracted_biomarkers: Dict[str, float] | None = None,
111
  ) -> AnalysisResponse:
112
  """Execute the ClinicalInsightGuild and build the response envelope."""
113
  request_id = f"req_{uuid.uuid4().hex[:12]}"
114
  t0 = time.time()
115
 
116
  ragbot = getattr(request.app.state, "ragbot_service", None)
117
- if ragbot is None or not ragbot.is_ready():
118
  raise HTTPException(status_code=503, detail="Analysis service unavailable. Please wait for initialization.")
119
 
120
  # Generate disease prediction
@@ -122,15 +122,16 @@ async def _run_guild_analysis(
122
 
123
  try:
124
  # Run sync function in thread pool
 
 
 
 
 
 
125
  loop = asyncio.get_running_loop()
126
  result = await loop.run_in_executor(
127
  _executor,
128
- lambda: ragbot.analyze(
129
- biomarkers=biomarkers,
130
- patient_context=patient_ctx,
131
- model_prediction=model_prediction,
132
- extracted_biomarkers=extracted_biomarkers
133
- )
134
  )
135
  except Exception as exc:
136
  logger.exception("Guild analysis failed: %s", exc)
@@ -142,20 +143,15 @@ async def _run_guild_analysis(
142
  elapsed = (time.time() - t0) * 1000
143
 
144
  # Build response from result
145
- # Guild workflow returns a dict; ragbot.analyze() may return dict or object
146
- if isinstance(result, dict):
147
- prediction = result.get('prediction')
148
- analysis = result.get('analysis')
149
- conversational_summary = result.get('conversational_summary')
150
- else:
151
- prediction = getattr(result, 'prediction', None)
152
- analysis = getattr(result, 'analysis', None)
153
- conversational_summary = getattr(result, 'conversational_summary', None)
154
 
155
  return AnalysisResponse(
156
  status="success",
157
  request_id=request_id,
158
- timestamp=datetime.now(timezone.utc).isoformat(),
159
  extracted_biomarkers=extracted_biomarkers,
160
  input_biomarkers=biomarkers,
161
  patient_context=patient_ctx,
 
12
  import time
13
  import uuid
14
  from concurrent.futures import ThreadPoolExecutor
15
+ from datetime import UTC, datetime
16
+ from typing import Any
17
 
18
  from fastapi import APIRouter, HTTPException, Request
19
 
 
30
  _executor = ThreadPoolExecutor(max_workers=4)
31
 
32
 
33
+ def _score_disease_heuristic(biomarkers: dict[str, float]) -> dict[str, Any]:
34
  """Rule-based disease scoring (NOT ML prediction)."""
35
  scores = {
36
  "Diabetes": 0.0,
 
39
  "Thrombocytopenia": 0.0,
40
  "Thalassemia": 0.0
41
  }
42
+
43
  # Diabetes indicators
44
  glucose = biomarkers.get("Glucose")
45
  hba1c = biomarkers.get("HbA1c")
 
49
  scores["Diabetes"] += 0.2
50
  if hba1c is not None and hba1c >= 6.5:
51
  scores["Diabetes"] += 0.5
52
+
53
  # Anemia indicators
54
  hemoglobin = biomarkers.get("Hemoglobin")
55
  mcv = biomarkers.get("Mean Corpuscular Volume", biomarkers.get("MCV"))
 
59
  scores["Anemia"] += 0.2
60
  if mcv is not None and mcv < 80:
61
  scores["Anemia"] += 0.2
62
+
63
  # Heart disease indicators
64
  cholesterol = biomarkers.get("Cholesterol")
65
  troponin = biomarkers.get("Troponin")
 
70
  scores["Heart Disease"] += 0.6
71
  if ldl is not None and ldl > 190:
72
  scores["Heart Disease"] += 0.2
73
+
74
  # Thrombocytopenia indicators
75
  platelets = biomarkers.get("Platelets")
76
  if platelets is not None and platelets < 150000:
77
  scores["Thrombocytopenia"] += 0.6
78
  if platelets is not None and platelets < 50000:
79
  scores["Thrombocytopenia"] += 0.3
80
+
81
  # Thalassemia indicators
82
  if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
83
  scores["Thalassemia"] += 0.4
84
+
85
  # Find top prediction
86
  top_disease = max(scores, key=scores.get)
87
  confidence = min(scores[top_disease], 1.0)
88
+
89
  if confidence == 0.0:
90
  top_disease = "Undetermined"
91
+
92
  # Normalize probabilities
93
  total = sum(scores.values())
94
  if total > 0:
95
  probabilities = {k: v / total for k, v in scores.items()}
96
  else:
97
  probabilities = {k: 1.0 / len(scores) for k in scores}
98
+
99
  return {
100
  "disease": top_disease,
101
  "confidence": confidence,
 
105
 
106
  async def _run_guild_analysis(
107
  request: Request,
108
+ biomarkers: dict[str, float],
109
+ patient_ctx: dict[str, Any],
110
+ extracted_biomarkers: dict[str, float] | None = None,
111
  ) -> AnalysisResponse:
112
  """Execute the ClinicalInsightGuild and build the response envelope."""
113
  request_id = f"req_{uuid.uuid4().hex[:12]}"
114
  t0 = time.time()
115
 
116
  ragbot = getattr(request.app.state, "ragbot_service", None)
117
+ if ragbot is None:
118
  raise HTTPException(status_code=503, detail="Analysis service unavailable. Please wait for initialization.")
119
 
120
  # Generate disease prediction
 
122
 
123
  try:
124
  # Run sync function in thread pool
125
+ from src.state import PatientInput
126
+ patient_input = PatientInput(
127
+ biomarkers=biomarkers,
128
+ patient_context=patient_ctx,
129
+ model_prediction=model_prediction
130
+ )
131
  loop = asyncio.get_running_loop()
132
  result = await loop.run_in_executor(
133
  _executor,
134
+ lambda: ragbot.run(patient_input)
 
 
 
 
 
135
  )
136
  except Exception as exc:
137
  logger.exception("Guild analysis failed: %s", exc)
 
143
  elapsed = (time.time() - t0) * 1000
144
 
145
  # Build response from result
146
+ prediction = result.get('model_prediction')
147
+ analysis = result.get('final_response', {})
148
+ # Try to extract the conversational_summary if it's there
149
+ conversational_summary = analysis.get('conversational_summary') if isinstance(analysis, dict) else str(analysis)
 
 
 
 
 
150
 
151
  return AnalysisResponse(
152
  status="success",
153
  request_id=request_id,
154
+ timestamp=datetime.now(UTC).isoformat(),
155
  extracted_biomarkers=extracted_biomarkers,
156
  input_biomarkers=biomarkers,
157
  patient_context=patient_ctx,
src/routers/ask.py CHANGED
@@ -12,13 +12,12 @@ import json
12
  import logging
13
  import time
14
  import uuid
15
- from datetime import datetime, timezone
16
- from typing import AsyncGenerator
17
 
18
  from fastapi import APIRouter, HTTPException, Request
19
  from fastapi.responses import StreamingResponse
20
 
21
- from src.schemas.schemas import AskRequest, AskResponse
22
 
23
  logger = logging.getLogger(__name__)
24
  router = APIRouter(tags=["ask"])
@@ -81,12 +80,12 @@ async def _stream_rag_response(
81
  - error: Error information
82
  """
83
  t0 = time.time()
84
-
85
  try:
86
  # Send initial status
87
  yield f"event: status\ndata: {json.dumps({'stage': 'guardrail', 'message': 'Validating query...'})}\n\n"
88
  await asyncio.sleep(0) # Allow event loop to flush
89
-
90
  # Run the RAG pipeline (synchronous, but we yield progress)
91
  loop = asyncio.get_running_loop()
92
  result = await loop.run_in_executor(
@@ -97,16 +96,16 @@ async def _stream_rag_response(
97
  patient_context=patient_context,
98
  )
99
  )
100
-
101
  # Send retrieval metadata
102
  yield f"event: metadata\ndata: {json.dumps({'documents_retrieved': len(result.get('retrieved_documents', [])), 'documents_relevant': len(result.get('relevant_documents', [])), 'guardrail_score': result.get('guardrail_score')})}\n\n"
103
  await asyncio.sleep(0)
104
-
105
  # Stream the answer token by token for smooth UI
106
  answer = result.get("final_answer", "")
107
  if answer:
108
  yield f"event: status\ndata: {json.dumps({'stage': 'generating', 'message': 'Generating response...'})}\n\n"
109
-
110
  # Simulate streaming by chunking the response
111
  words = answer.split()
112
  chunk_size = 3 # Send 3 words at a time
@@ -116,11 +115,11 @@ async def _stream_rag_response(
116
  chunk += " "
117
  yield f"event: token\ndata: {json.dumps({'text': chunk})}\n\n"
118
  await asyncio.sleep(0.02) # Small delay for visual streaming effect
119
-
120
  # Send completion
121
  elapsed = (time.time() - t0) * 1000
122
  yield f"event: done\ndata: {json.dumps({'request_id': request_id, 'processing_time_ms': round(elapsed, 1), 'status': 'success'})}\n\n"
123
-
124
  except Exception as exc:
125
  logger.exception("Streaming RAG failed: %s", exc)
126
  yield f"event: error\ndata: {json.dumps({'error': str(exc), 'request_id': request_id})}\n\n"
@@ -154,9 +153,9 @@ async def ask_medical_question_stream(body: AskRequest, request: Request):
154
  rag_service = getattr(request.app.state, "rag_service", None)
155
  if rag_service is None:
156
  raise HTTPException(status_code=503, detail="RAG service unavailable")
157
-
158
  request_id = f"req_{uuid.uuid4().hex[:12]}"
159
-
160
  return StreamingResponse(
161
  _stream_rag_response(
162
  rag_service,
@@ -172,3 +171,17 @@ async def ask_medical_question_stream(body: AskRequest, request: Request):
172
  "X-Request-ID": request_id,
173
  },
174
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  import logging
13
  import time
14
  import uuid
15
+ from collections.abc import AsyncGenerator
 
16
 
17
  from fastapi import APIRouter, HTTPException, Request
18
  from fastapi.responses import StreamingResponse
19
 
20
+ from src.schemas.schemas import AskRequest, AskResponse, FeedbackRequest, FeedbackResponse
21
 
22
  logger = logging.getLogger(__name__)
23
  router = APIRouter(tags=["ask"])
 
80
  - error: Error information
81
  """
82
  t0 = time.time()
83
+
84
  try:
85
  # Send initial status
86
  yield f"event: status\ndata: {json.dumps({'stage': 'guardrail', 'message': 'Validating query...'})}\n\n"
87
  await asyncio.sleep(0) # Allow event loop to flush
88
+
89
  # Run the RAG pipeline (synchronous, but we yield progress)
90
  loop = asyncio.get_running_loop()
91
  result = await loop.run_in_executor(
 
96
  patient_context=patient_context,
97
  )
98
  )
99
+
100
  # Send retrieval metadata
101
  yield f"event: metadata\ndata: {json.dumps({'documents_retrieved': len(result.get('retrieved_documents', [])), 'documents_relevant': len(result.get('relevant_documents', [])), 'guardrail_score': result.get('guardrail_score')})}\n\n"
102
  await asyncio.sleep(0)
103
+
104
  # Stream the answer token by token for smooth UI
105
  answer = result.get("final_answer", "")
106
  if answer:
107
  yield f"event: status\ndata: {json.dumps({'stage': 'generating', 'message': 'Generating response...'})}\n\n"
108
+
109
  # Simulate streaming by chunking the response
110
  words = answer.split()
111
  chunk_size = 3 # Send 3 words at a time
 
115
  chunk += " "
116
  yield f"event: token\ndata: {json.dumps({'text': chunk})}\n\n"
117
  await asyncio.sleep(0.02) # Small delay for visual streaming effect
118
+
119
  # Send completion
120
  elapsed = (time.time() - t0) * 1000
121
  yield f"event: done\ndata: {json.dumps({'request_id': request_id, 'processing_time_ms': round(elapsed, 1), 'status': 'success'})}\n\n"
122
+
123
  except Exception as exc:
124
  logger.exception("Streaming RAG failed: %s", exc)
125
  yield f"event: error\ndata: {json.dumps({'error': str(exc), 'request_id': request_id})}\n\n"
 
153
  rag_service = getattr(request.app.state, "rag_service", None)
154
  if rag_service is None:
155
  raise HTTPException(status_code=503, detail="RAG service unavailable")
156
+
157
  request_id = f"req_{uuid.uuid4().hex[:12]}"
158
+
159
  return StreamingResponse(
160
  _stream_rag_response(
161
  rag_service,
 
171
  "X-Request-ID": request_id,
172
  },
173
  )
174
+
175
+
176
+ @router.post("/feedback", response_model=FeedbackResponse)
177
+ async def submit_feedback(body: FeedbackRequest, request: Request):
178
+ """Submit user feedback for an analysis or RAG response."""
179
+ tracer = getattr(request.app.state, "tracer", None)
180
+ if tracer:
181
+ tracer.score(
182
+ trace_id=body.request_id,
183
+ name="user-feedback",
184
+ value=body.score,
185
+ comment=body.comment
186
+ )
187
+ return FeedbackResponse(request_id=body.request_id)
src/routers/health.py CHANGED
@@ -7,7 +7,7 @@ Provides /health and /health/ready with per-service checks.
7
  from __future__ import annotations
8
 
9
  import time
10
- from datetime import datetime, timezone
11
 
12
  from fastapi import APIRouter, Request
13
 
@@ -23,7 +23,7 @@ async def health_check(request: Request) -> HealthResponse:
23
  uptime = time.time() - getattr(app_state, "start_time", time.time())
24
  return HealthResponse(
25
  status="healthy",
26
- timestamp=datetime.now(timezone.utc).isoformat(),
27
  version=getattr(app_state, "version", "2.0.0"),
28
  uptime_seconds=round(uptime, 2),
29
  )
@@ -39,9 +39,10 @@ async def readiness_check(request: Request) -> HealthResponse:
39
 
40
  # --- PostgreSQL ---
41
  try:
42
- from src.database import get_engine
43
  from sqlalchemy import text
44
- engine = get_engine()
 
 
45
  if engine is not None:
46
  t0 = time.time()
47
  with engine.connect() as conn:
@@ -86,9 +87,10 @@ async def readiness_check(request: Request) -> HealthResponse:
86
  ollama = getattr(app_state, "ollama_client", None)
87
  if ollama is not None:
88
  t0 = time.time()
89
- healthy = ollama.health()
90
  latency = (time.time() - t0) * 1000
91
- services.append(ServiceHealth(name="ollama", status="ok" if healthy else "degraded", latency_ms=round(latency, 1)))
 
92
  else:
93
  services.append(ServiceHealth(name="ollama", status="unavailable"))
94
  except Exception as exc:
@@ -126,7 +128,7 @@ async def readiness_check(request: Request) -> HealthResponse:
126
 
127
  return HealthResponse(
128
  status=overall,
129
- timestamp=datetime.now(timezone.utc).isoformat(),
130
  version=getattr(app_state, "version", "2.0.0"),
131
  uptime_seconds=round(uptime, 2),
132
  services=services,
 
7
  from __future__ import annotations
8
 
9
  import time
10
+ from datetime import UTC, datetime
11
 
12
  from fastapi import APIRouter, Request
13
 
 
23
  uptime = time.time() - getattr(app_state, "start_time", time.time())
24
  return HealthResponse(
25
  status="healthy",
26
+ timestamp=datetime.now(UTC).isoformat(),
27
  version=getattr(app_state, "version", "2.0.0"),
28
  uptime_seconds=round(uptime, 2),
29
  )
 
39
 
40
  # --- PostgreSQL ---
41
  try:
 
42
  from sqlalchemy import text
43
+
44
+ from src.database import _engine
45
+ engine = _engine()
46
  if engine is not None:
47
  t0 = time.time()
48
  with engine.connect() as conn:
 
87
  ollama = getattr(app_state, "ollama_client", None)
88
  if ollama is not None:
89
  t0 = time.time()
90
+ health_info = ollama.health()
91
  latency = (time.time() - t0) * 1000
92
+ is_healthy = isinstance(health_info, dict) and health_info.get("status") == "ok"
93
+ services.append(ServiceHealth(name="ollama", status="ok" if is_healthy else "degraded", latency_ms=round(latency, 1)))
94
  else:
95
  services.append(ServiceHealth(name="ollama", status="unavailable"))
96
  except Exception as exc:
 
128
 
129
  return HealthResponse(
130
  status=overall,
131
+ timestamp=datetime.now(UTC).isoformat(),
132
  version=getattr(app_state, "version", "2.0.0"),
133
  uptime_seconds=round(uptime, 2),
134
  services=services,