diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000000000000000000000000000000000000..12526469e3740b0bf0f3ac21961879c18e4c4f6c
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,61 @@
+# ===========================================================================
+# MediGuard AI — Environment Variables
+# ===========================================================================
+# Copy this file to .env and fill in your values.
+# ===========================================================================
+
+# --- API ---
+API__HOST=0.0.0.0
+API__PORT=8000
+API__DEBUG=true
+CORS_ALLOWED_ORIGINS=*
+
+# --- PostgreSQL ---
+POSTGRES__HOST=localhost
+POSTGRES__PORT=5432
+POSTGRES__DATABASE=mediguard
+POSTGRES__USER=mediguard
+POSTGRES__PASSWORD=mediguard_secret
+
+# --- OpenSearch ---
+OPENSEARCH__HOST=localhost
+OPENSEARCH__PORT=9200
+
+# --- Redis ---
+REDIS__HOST=localhost
+REDIS__PORT=6379
+REDIS__ENABLED=true
+
+# --- Ollama ---
+OLLAMA__BASE_URL=http://localhost:11434
+OLLAMA__MODEL=llama3.2
+
+# --- LLM (Groq / Gemini — existing providers) ---
+LLM__PRIMARY_PROVIDER=groq
+LLM__GROQ_API_KEY=
+LLM__GROQ_MODEL=llama-3.3-70b-versatile
+LLM__GEMINI_API_KEY=
+LLM__GEMINI_MODEL=gemini-2.0-flash
+
+# --- Embeddings ---
+EMBEDDING__PROVIDER=jina
+EMBEDDING__JINA_API_KEY=
+EMBEDDING__MODEL_NAME=jina-embeddings-v3
+EMBEDDING__DIMENSION=1024
+
+# --- Langfuse ---
+LANGFUSE__ENABLED=true
+LANGFUSE__PUBLIC_KEY=
+LANGFUSE__SECRET_KEY=
+LANGFUSE__HOST=http://localhost:3000
+
+# --- Chunking ---
+CHUNKING__CHUNK_SIZE=1024
+CHUNKING__CHUNK_OVERLAP=128
+
+# --- Telegram Bot (optional) ---
+TELEGRAM__BOT_TOKEN=
+TELEGRAM__API_BASE_URL=http://localhost:8000
+
+# --- Medical PDFs ---
+MEDICAL_PDFS__DIRECTORY=data/medical_pdfs
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index
0000000000000000000000000000000000000000..9a2e852147524cfe46cfa5895d9abf8686701822 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +# MediGuard AI — Pre-commit hooks +# Install: pre-commit install +# Run all: pre-commit run --all-files + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-json + - id: check-merge-conflict + - id: detect-private-key + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.0 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.12.0 + hooks: + - id: mypy + additional_dependencies: [pydantic>=2.0] + args: [--ignore-missing-imports] diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..9f3d79001e1514882d830166bcea8b0dcc31f710 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,66 @@ +# =========================================================================== +# MediGuard AI — Multi-stage Dockerfile +# =========================================================================== +# Build stages: +# base — Python + system deps +# production — slim runtime image +# =========================================================================== + +# --------------------------------------------------------------------------- +# Stage 1: base +# --------------------------------------------------------------------------- +FROM python:3.11-slim AS base + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 + +WORKDIR /app + +# System dependencies +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY pyproject.toml ./ +RUN pip install --upgrade pip && \ + pip install ".[all]" + +# 
--------------------------------------------------------------------------- +# Stage 2: production +# --------------------------------------------------------------------------- +FROM python:3.11-slim AS production + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +WORKDIR /app + +# Copy installed packages from base +COPY --from=base /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages +COPY --from=base /usr/local/bin /usr/local/bin + +# Copy application code +COPY . . + +# Runtime dependencies only +RUN apt-get update && \ + apt-get install -y --no-install-recommends curl && \ + rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN groupadd -r mediguard && \ + useradd -r -g mediguard -d /app -s /sbin/nologin mediguard && \ + chown -R mediguard:mediguard /app + +USER mediguard + +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=5s --retries=3 \ + CMD curl -sf http://localhost:8000/health || exit 1 + +CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..631f2ac0d6e59671117bf5ca337034456bc35977 --- /dev/null +++ b/Makefile @@ -0,0 +1,137 @@ +# =========================================================================== +# MediGuard AI — Makefile +# =========================================================================== +# Usage: +# make help — show all targets +# make setup — install deps + pre-commit hooks +# make dev — run API in dev mode with reload +# make test — run full test suite +# make lint — ruff check + mypy +# make docker-up — spin up all Docker services +# make docker-down — tear down Docker services +# =========================================================================== + +.DEFAULT_GOAL := help +SHELL := /bin/bash + +# Python / UV +PYTHON ?= python +UV ?= uv +PIP ?= pip + +# Docker +COMPOSE := docker compose + +# 
--------------------------------------------------------------------------- +# Help +# --------------------------------------------------------------------------- +.PHONY: help +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- +.PHONY: setup +setup: ## Install all deps (pip) + pre-commit hooks + $(PIP) install -e ".[all]" + pre-commit install + +.PHONY: setup-uv +setup-uv: ## Install all deps with UV + $(UV) pip install -e ".[all]" + pre-commit install + +# --------------------------------------------------------------------------- +# Development +# --------------------------------------------------------------------------- +.PHONY: dev +dev: ## Run API in dev mode (auto-reload) + uvicorn src.main:app --host 0.0.0.0 --port 8000 --reload + +.PHONY: gradio +gradio: ## Launch Gradio web UI + $(PYTHON) -m src.gradio_app + +.PHONY: telegram +telegram: ## Start Telegram bot + $(PYTHON) -c "from src.services.telegram.bot import MediGuardTelegramBot; MediGuardTelegramBot().run()" + +# --------------------------------------------------------------------------- +# Quality +# --------------------------------------------------------------------------- +.PHONY: lint +lint: ## Ruff check + MyPy + ruff check src/ tests/ + mypy src/ --ignore-missing-imports + +.PHONY: format +format: ## Ruff format + ruff format src/ tests/ + ruff check --fix src/ tests/ + +.PHONY: test +test: ## Run pytest with coverage + pytest tests/ -v --tb=short --cov=src --cov-report=term-missing + +.PHONY: test-quick +test-quick: ## Run only fast unit tests + pytest tests/ -v --tb=short -m "not slow" + +# --------------------------------------------------------------------------- +# Docker +# 
--------------------------------------------------------------------------- +.PHONY: docker-up +docker-up: ## Start all Docker services (detached) + $(COMPOSE) up -d + +.PHONY: docker-down +docker-down: ## Stop and remove Docker services + $(COMPOSE) down -v + +.PHONY: docker-build +docker-build: ## Build Docker images + $(COMPOSE) build + +.PHONY: docker-logs +docker-logs: ## Tail Docker logs + $(COMPOSE) logs -f + +# --------------------------------------------------------------------------- +# Database +# --------------------------------------------------------------------------- +.PHONY: db-upgrade +db-upgrade: ## Run Alembic migrations + alembic upgrade head + +.PHONY: db-revision +db-revision: ## Create a new Alembic migration + alembic revision --autogenerate -m "$(msg)" + +# --------------------------------------------------------------------------- +# Indexing +# --------------------------------------------------------------------------- +.PHONY: index-pdfs +index-pdfs: ## Parse and index all medical PDFs + $(PYTHON) -c "\ +from pathlib import Path; \ +from src.services.pdf_parser.service import make_pdf_parser_service; \ +from src.services.indexing.service import IndexingService; \ +from src.services.embeddings.service import make_embedding_service; \ +from src.services.opensearch.client import make_opensearch_client; \ +parser = make_pdf_parser_service(); \ +idx = IndexingService(make_embedding_service(), make_opensearch_client()); \ +docs = parser.parse_directory(Path('data/medical_pdfs')); \ +[idx.index_text(d.full_text, {'title': d.filename}) for d in docs if d.full_text]; \ +print(f'Indexed {len(docs)} documents')" + +# --------------------------------------------------------------------------- +# Clean +# --------------------------------------------------------------------------- +.PHONY: clean +clean: ## Remove build artifacts and caches + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . 
def _ingest_pdfs(**kwargs):
    """Parse all PDFs and index into OpenSearch."""
    # Imports are deferred so the DAG file stays cheap to parse and the
    # scheduler does not need the full project import graph at parse time.
    from pathlib import Path

    from src.services.embeddings.service import make_embedding_service
    from src.services.indexing.service import IndexingService
    from src.services.opensearch.client import make_opensearch_client
    from src.services.pdf_parser.service import make_pdf_parser_service
    from src.settings import get_settings

    source_dir = Path(get_settings().medical_pdfs.directory)

    pdf_parser = make_pdf_parser_service()
    indexer = IndexingService(make_embedding_service(), make_opensearch_client())

    parsed_docs = pdf_parser.parse_directory(source_dir)
    success_count = 0
    for parsed in parsed_docs:
        # Skip documents whose text extraction failed or produced nothing.
        if not parsed.full_text or parsed.error:
            continue
        indexer.index_text(parsed.full_text, {"title": parsed.filename})
        success_count += 1

    print(f"Ingested {success_count}/{len(parsed_docs)} documents")
    return {"total": len(parsed_docs), "indexed": success_count}
def _run_evolution(**kwargs):
    """Execute one SOP evolution cycle."""
    # Deferred import keeps DAG parsing decoupled from the src/ package.
    from src.evolution.director import run_evolution_cycle

    outcome = run_evolution_cycle()
    print(f"Evolution cycle complete: {outcome}")
    return outcome
=========================================================================== +# Usage: +# docker compose up -d — start all services +# docker compose down -v — stop and remove volumes +# docker compose logs -f api — follow API logs +# =========================================================================== + +services: + # ----------------------------------------------------------------------- + # Application + # ----------------------------------------------------------------------- + api: + build: + context: . + dockerfile: Dockerfile + target: production + container_name: mediguard-api + ports: + - "${API_PORT:-8000}:8000" + env_file: .env + environment: + - POSTGRES__HOST=postgres + - OPENSEARCH__HOST=opensearch + - OPENSEARCH__PORT=9200 + - REDIS__HOST=redis + - REDIS__PORT=6379 + - OLLAMA__BASE_URL=http://ollama:11434 + - LANGFUSE__HOST=http://langfuse:3000 + depends_on: + postgres: + condition: service_healthy + opensearch: + condition: service_healthy + redis: + condition: service_healthy + volumes: + - ./data/medical_pdfs:/app/data/medical_pdfs:ro + restart: unless-stopped + + gradio: + build: + context: . 
+ dockerfile: Dockerfile + target: production + container_name: mediguard-gradio + command: python -m src.gradio_app + ports: + - "${GRADIO_PORT:-7860}:7860" + environment: + - MEDIGUARD_API_URL=http://api:8000 + depends_on: + - api + restart: unless-stopped + + # ----------------------------------------------------------------------- + # Backing services + # ----------------------------------------------------------------------- + postgres: + image: postgres:16-alpine + container_name: mediguard-postgres + environment: + POSTGRES_DB: ${POSTGRES__DATABASE:-mediguard} + POSTGRES_USER: ${POSTGRES__USER:-mediguard} + POSTGRES_PASSWORD: ${POSTGRES__PASSWORD:-mediguard_secret} + ports: + - "${POSTGRES_PORT:-5432}:5432" + volumes: + - pg_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U mediguard"] + interval: 5s + timeout: 3s + retries: 10 + restart: unless-stopped + + opensearch: + image: opensearchproject/opensearch:2.19.0 + container_name: mediguard-opensearch + environment: + - discovery.type=single-node + - DISABLE_SECURITY_PLUGIN=true + - "OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m" + - bootstrap.memory_lock=true + ulimits: + memlock: { soft: -1, hard: -1 } + nofile: { soft: 65536, hard: 65536 } + ports: + - "${OPENSEARCH_PORT:-9200}:9200" + volumes: + - os_data:/usr/share/opensearch/data + healthcheck: + test: ["CMD-SHELL", "curl -sf http://localhost:9200/_cluster/health || exit 1"] + interval: 10s + timeout: 5s + retries: 20 + restart: unless-stopped + + opensearch-dashboards: + image: opensearchproject/opensearch-dashboards:2.19.0 + container_name: mediguard-os-dashboards + environment: + - OPENSEARCH_HOSTS=["http://opensearch:9200"] + - DISABLE_SECURITY_DASHBOARDS_PLUGIN=true + ports: + - "${OS_DASHBOARDS_PORT:-5601}:5601" + depends_on: + opensearch: + condition: service_healthy + restart: unless-stopped + + redis: + image: redis:7-alpine + container_name: mediguard-redis + ports: + - "${REDIS_PORT:-6379}:6379" + volumes: + - 
redis_data:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 3s + retries: 10 + restart: unless-stopped + + ollama: + image: ollama/ollama:latest + container_name: mediguard-ollama + ports: + - "${OLLAMA_PORT:-11434}:11434" + volumes: + - ollama_data:/root/.ollama + restart: unless-stopped + # Uncomment for GPU support: + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] + + # ----------------------------------------------------------------------- + # Observability + # ----------------------------------------------------------------------- + langfuse: + image: langfuse/langfuse:2 + container_name: mediguard-langfuse + environment: + - DATABASE_URL=postgresql://mediguard:mediguard_secret@postgres:5432/langfuse + - NEXTAUTH_URL=http://localhost:3000 + - NEXTAUTH_SECRET=mediguard-langfuse-secret-change-me + - SALT=mediguard-langfuse-salt-change-me + ports: + - "${LANGFUSE_PORT:-3000}:3000" + depends_on: + postgres: + condition: service_healthy + restart: unless-stopped + +volumes: + pg_data: + os_data: + redis_data: + ollama_data: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..4388e198e05679b3e6c440b7e571c7af1697bdd0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,117 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "mediguard-ai" +version = "2.0.0" +description = "Production medical biomarker analysis — agentic RAG + multi-agent workflow" +readme = "README.md" +license = { text = "MIT" } +requires-python = ">=3.11" +authors = [{ name = "MediGuard AI Team" }] + +dependencies = [ + # --- Core --- + "fastapi>=0.115.0", + "uvicorn[standard]>=0.30.0", + "pydantic>=2.9.0", + "pydantic-settings>=2.5.0", + # --- LLM / LangChain --- + "langchain>=0.3.0", + "langchain-community>=0.3.0", + "langgraph>=0.2.0", + # --- Vector / Search --- + "opensearch-py>=2.7.0", + 
"faiss-cpu>=1.8.0", + # --- Embeddings --- + "httpx>=0.27.0", + # --- Database --- + "sqlalchemy>=2.0.0", + "psycopg2-binary>=2.9.0", + "alembic>=1.13.0", + # --- Cache --- + "redis>=5.0.0", + # --- PDF --- + "pypdf>=4.0.0", + # --- Observability --- + "langfuse>=2.0.0", + # --- Utilities --- + "python-dotenv>=1.0.0", + "tenacity>=8.0.0", +] + +[project.optional-dependencies] +docling = ["docling>=2.0.0"] +telegram = ["python-telegram-bot>=21.0", "httpx>=0.27.0"] +gradio = ["gradio>=5.0.0", "httpx>=0.27.0"] +airflow = ["apache-airflow>=2.9.0"] +google = ["langchain-google-genai>=2.0.0"] +groq = ["langchain-groq>=0.2.0"] +huggingface = ["sentence-transformers>=3.0.0"] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=5.0.0", + "ruff>=0.7.0", + "mypy>=1.12.0", + "pre-commit>=3.8.0", + "httpx>=0.27.0", +] +all = [ + "mediguard-ai[docling,telegram,gradio,google,groq,huggingface,dev]", +] + +[project.scripts] +mediguard = "src.main:app" +mediguard-telegram = "src.services.telegram.bot:MediGuardTelegramBot" +mediguard-gradio = "src.gradio_app:launch_gradio" + +# -------------------------------------------------------------------------- +# Ruff +# -------------------------------------------------------------------------- +[tool.ruff] +target-version = "py311" +line-length = 120 +fix = true + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "RUF", # ruff-specific +] +ignore = [ + "E501", # line too long — handled by formatter + "B008", # do not perform function calls in argument defaults (Depends) + "SIM108", # ternary operator +] + +[tool.ruff.lint.isort] +known-first-party = ["src"] + +# -------------------------------------------------------------------------- +# MyPy +# -------------------------------------------------------------------------- +[tool.mypy] +python_version = "3.11" 
+warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false # gradually enable +ignore_missing_imports = true + +# -------------------------------------------------------------------------- +# Pytest +# -------------------------------------------------------------------------- +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_functions = ["test_*"] +addopts = "-v --tb=short -q" +filterwarnings = ["ignore::DeprecationWarning"] diff --git a/src/database.py b/src/database.py new file mode 100644 index 0000000000000000000000000000000000000000..6111e83049b25728ab827313378e402824733591 --- /dev/null +++ b/src/database.py @@ -0,0 +1,50 @@ +""" +MediGuard AI — Database layer + +Provides SQLAlchemy engine/session factories and the declarative Base. +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Generator + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker, DeclarativeBase + +from src.settings import get_settings + + +class Base(DeclarativeBase): + """Shared declarative base for all ORM models.""" + pass + + +@lru_cache(maxsize=1) +def _engine(): + settings = get_settings() + return create_engine( + settings.postgres.database_url, + pool_pre_ping=True, + pool_size=5, + max_overflow=10, + echo=settings.debug, + ) + + +@lru_cache(maxsize=1) +def _session_factory() -> sessionmaker[Session]: + return sessionmaker(bind=_engine(), autocommit=False, autoflush=False) + + +def get_db() -> Generator[Session, None, None]: + """FastAPI dependency — yields a DB session and commits/rolls back.""" + session = _session_factory()() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/dependencies.py b/src/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..77f35d9756a7f03569e33eea1bee6164c48b4cc4 --- /dev/null +++ 
"""
MediGuard AI — FastAPI Dependency Injection

Provides factory functions and ``Depends()`` for services used across routers.
"""

from __future__ import annotations

# NOTE: unused imports (functools.lru_cache, Settings/get_settings) removed —
# ruff F401 (enabled in pyproject) would fail CI on them.
from src.services.cache.redis_cache import RedisCache, make_redis_cache
from src.services.embeddings.service import EmbeddingService, make_embedding_service
from src.services.langfuse.tracer import LangfuseTracer, make_langfuse_tracer
from src.services.ollama.client import OllamaClient, make_ollama_client
from src.services.opensearch.client import OpenSearchClient, make_opensearch_client


def get_opensearch_client() -> OpenSearchClient:
    """Dependency: build the OpenSearch client via its factory."""
    return make_opensearch_client()


def get_embedding_service() -> EmbeddingService:
    """Dependency: build the embedding service via its factory."""
    return make_embedding_service()


def get_redis_cache() -> RedisCache:
    """Dependency: build the Redis cache wrapper via its factory."""
    return make_redis_cache()


def get_ollama_client() -> OllamaClient:
    """Dependency: build the Ollama LLM client via its factory."""
    return make_ollama_client()


def get_langfuse_tracer() -> LangfuseTracer:
    """Dependency: build the Langfuse tracer via its factory."""
    return make_langfuse_tracer()
+""" + +from typing import Any, Dict, Optional + + +# ── Base ────────────────────────────────────────────────────────────────────── + +class MediGuardError(Exception): + """Root exception for the entire MediGuard AI application.""" + + def __init__(self, message: str = "", *, details: Optional[Dict[str, Any]] = None): + self.details = details or {} + super().__init__(message) + + +# ── Configuration / startup ────────────────────────────────────────────────── + +class ConfigurationError(MediGuardError): + """Raised when a required setting is missing or invalid.""" + + +class ServiceInitError(MediGuardError): + """Raised when a service fails to initialise during app startup.""" + + +# ── Database ───────────────────────────────────────────────────────────────── + +class DatabaseError(MediGuardError): + """Base class for all database-related errors.""" + + +class ConnectionError(DatabaseError): + """Could not connect to PostgreSQL.""" + + +class RecordNotFoundError(DatabaseError): + """Expected record does not exist.""" + + +# ── Search engine ──────────────────────────────────────────────────────────── + +class SearchError(MediGuardError): + """Base class for search-engine (OpenSearch) errors.""" + + +class IndexNotFoundError(SearchError): + """The requested OpenSearch index does not exist.""" + + +class SearchQueryError(SearchError): + """The search query was malformed or returned an error.""" + + +# ── Embeddings ─────────────────────────────────────────────────────────────── + +class EmbeddingError(MediGuardError): + """Failed to generate embeddings.""" + + +class EmbeddingProviderError(EmbeddingError): + """The upstream embedding provider returned an error.""" + + +# ── PDF / document parsing ─────────────────────────────────────────────────── + +class PDFParsingError(MediGuardError): + """Base class for PDF-processing errors.""" + + +class PDFExtractionError(PDFParsingError): + """Could not extract text from a PDF document.""" + + +class 
PDFValidationError(PDFParsingError): + """Uploaded PDF failed validation (size, format, etc.).""" + + +# ── LLM / Ollama ───────────────────────────────────────────────────────────── + +class LLMError(MediGuardError): + """Base class for LLM-related errors.""" + + +class OllamaConnectionError(LLMError): + """Could not reach the Ollama server.""" + + +class OllamaModelNotFoundError(LLMError): + """The requested Ollama model is not pulled/available.""" + + +class LLMResponseError(LLMError): + """The LLM returned an unparseable or empty response.""" + + +# ── Biomarker domain ───────────────────────────────────────────────────────── + +class BiomarkerError(MediGuardError): + """Base class for biomarker-related errors.""" + + +class BiomarkerValidationError(BiomarkerError): + """A biomarker value is physiologically implausible.""" + + +class BiomarkerNotFoundError(BiomarkerError): + """The biomarker name is unknown to the system.""" + + +# ── Medical analysis / workflow ────────────────────────────────────────────── + +class AnalysisError(MediGuardError): + """The clinical-analysis workflow encountered an error.""" + + +class GuardrailError(MediGuardError): + """A safety guardrail was triggered (input or output).""" + + +class OutOfScopeError(GuardrailError): + """The user query falls outside the medical domain.""" + + +# ── Cache ──────────────────────────────────────────────────────────────────── + +class CacheError(MediGuardError): + """Base class for cache (Redis) errors.""" + + +class CacheConnectionError(CacheError): + """Could not connect to Redis.""" + + +# ── Observability ──────────────────────────────────────────────────────────── + +class ObservabilityError(MediGuardError): + """Langfuse or metrics reporting failed (non-fatal).""" + + +# ── Telegram bot ───────────────────────────────────────────────────────────── + +class TelegramError(MediGuardError): + """Error from the Telegram bot integration.""" diff --git a/src/gradio_app.py b/src/gradio_app.py new 
file mode 100644 index 0000000000000000000000000000000000000000..fee2e2d31434d4e7753e389e613764f7dae83366 --- /dev/null +++ b/src/gradio_app.py @@ -0,0 +1,121 @@ +""" +MediGuard AI — Gradio Web UI + +Provides a simple chat interface and biomarker analysis panel. +""" + +from __future__ import annotations + +import json +import logging +import os + +import httpx + +logger = logging.getLogger(__name__) + +API_BASE = os.getenv("MEDIGUARD_API_URL", "http://localhost:8000") + + +def _call_ask(question: str) -> str: + """Call the /ask endpoint.""" + try: + with httpx.Client(timeout=60.0) as client: + resp = client.post(f"{API_BASE}/ask", json={"question": question}) + resp.raise_for_status() + return resp.json().get("answer", "No answer returned.") + except Exception as exc: + return f"Error: {exc}" + + +def _call_analyze(biomarkers_json: str) -> str: + """Call the /analyze/structured endpoint.""" + try: + biomarkers = json.loads(biomarkers_json) + with httpx.Client(timeout=60.0) as client: + resp = client.post( + f"{API_BASE}/analyze/structured", + json={"biomarkers": biomarkers}, + ) + resp.raise_for_status() + data = resp.json() + summary = data.get("conversational_summary") or json.dumps(data, indent=2) + return summary + except json.JSONDecodeError: + return "Invalid JSON. Please enter biomarkers as: {\"Glucose\": 185, \"HbA1c\": 8.2}" + except Exception as exc: + return f"Error: {exc}" + + +def launch_gradio(share: bool = False) -> None: + """Launch the Gradio interface.""" + try: + import gradio as gr + except ImportError: + raise ImportError("gradio is required. Install: pip install gradio") + + with gr.Blocks(title="MediGuard AI", theme=gr.themes.Soft()) as demo: + gr.Markdown("# 🏥 MediGuard AI — Medical Analysis") + gr.Markdown( + "**Disclaimer**: This tool is for informational purposes only and does not " + "replace professional medical advice." 
+ ) + + with gr.Tab("Ask a Question"): + question_input = gr.Textbox( + label="Medical Question", + placeholder="e.g., What does a high HbA1c level indicate?", + lines=3, + ) + ask_btn = gr.Button("Ask", variant="primary") + answer_output = gr.Textbox(label="Answer", lines=15, interactive=False) + ask_btn.click(fn=_call_ask, inputs=question_input, outputs=answer_output) + + with gr.Tab("Analyze Biomarkers"): + bio_input = gr.Textbox( + label="Biomarkers (JSON)", + placeholder='{"Glucose": 185, "HbA1c": 8.2, "Cholesterol": 210}', + lines=5, + ) + analyze_btn = gr.Button("Analyze", variant="primary") + analysis_output = gr.Textbox(label="Analysis", lines=20, interactive=False) + analyze_btn.click(fn=_call_analyze, inputs=bio_input, outputs=analysis_output) + + with gr.Tab("Search Knowledge Base"): + search_input = gr.Textbox( + label="Search Query", + placeholder="e.g., diabetes management guidelines", + lines=2, + ) + search_btn = gr.Button("Search", variant="primary") + search_output = gr.Textbox(label="Results", lines=15, interactive=False) + + def _call_search(query: str) -> str: + try: + with httpx.Client(timeout=30.0) as client: + resp = client.post( + f"{API_BASE}/search", + json={"query": query, "top_k": 5, "mode": "hybrid"}, + ) + resp.raise_for_status() + data = resp.json() + results = data.get("results", []) + if not results: + return "No results found." 
"""
MediGuard AI — Production FastAPI Application

Central app factory with a lifespan that initialises all production services
(OpenSearch, Redis, Ollama, Langfuse, RAG pipeline) and gracefully shuts
them down. The existing ``api/`` package is kept as-is — this module is the
primary production entry-point (``uvicorn src.main:app``).
"""

from __future__ import annotations

import logging
import os
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone

from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from src.settings import get_settings

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(name)-30s | %(levelname)-7s | %(message)s",
)
logger = logging.getLogger("mediguard")

# Single source of truth for the reported application version.
APP_VERSION = "2.0.0"

# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise production services on startup, tear them down on shutdown.

    Every backing service is optional: if one fails to initialise we log a
    warning and store ``None`` on ``app.state`` so the routers can degrade
    gracefully (they answer 503 when their dependency is missing).
    """
    # Warm the cached settings so broken configuration fails fast at startup
    # instead of on the first request. (The previous local assignment was
    # unused.)
    get_settings()

    app.state.start_time = time.time()
    app.state.version = APP_VERSION

    logger.info("=" * 70)
    logger.info("MediGuard AI — starting production server v%s", app.state.version)
    logger.info("=" * 70)

    # --- OpenSearch ---
    try:
        from src.services.opensearch.client import make_opensearch_client

        app.state.opensearch_client = make_opensearch_client()
        logger.info("OpenSearch client ready")
    except Exception as exc:
        logger.warning("OpenSearch unavailable: %s", exc)
        app.state.opensearch_client = None

    # --- Embedding service ---
    try:
        from src.services.embeddings.service import make_embedding_service

        app.state.embedding_service = make_embedding_service()
        # Don't assume the private ``_provider`` attribute exists — fall back
        # to "unknown" so a refactor of the service can't break startup.
        logger.info(
            "Embedding service ready (provider=%s)",
            getattr(app.state.embedding_service, "_provider", "unknown"),
        )
    except Exception as exc:
        logger.warning("Embedding service unavailable: %s", exc)
        app.state.embedding_service = None

    # --- Redis cache ---
    try:
        from src.services.cache.redis_cache import make_redis_cache

        app.state.cache = make_redis_cache()
        logger.info("Redis cache ready")
    except Exception as exc:
        logger.warning("Redis cache unavailable: %s", exc)
        app.state.cache = None

    # --- Ollama LLM ---
    try:
        from src.services.ollama.client import make_ollama_client

        app.state.ollama_client = make_ollama_client()
        logger.info("Ollama client ready")
    except Exception as exc:
        logger.warning("Ollama client unavailable: %s", exc)
        app.state.ollama_client = None

    # --- Langfuse tracer ---
    try:
        from src.services.langfuse.tracer import make_langfuse_tracer

        app.state.tracer = make_langfuse_tracer()
        logger.info("Langfuse tracer ready")
    except Exception as exc:
        logger.warning("Langfuse tracer unavailable: %s", exc)
        app.state.tracer = None

    # --- Agentic RAG service (needs LLM + search + embeddings) ---
    try:
        from src.services.agents.agentic_rag import AgenticRAGService
        from src.services.agents.context import AgenticContext

        if app.state.ollama_client and app.state.opensearch_client and app.state.embedding_service:
            llm = app.state.ollama_client.get_langchain_model()
            ctx = AgenticContext(
                llm=llm,
                embedding_service=app.state.embedding_service,
                opensearch_client=app.state.opensearch_client,
                cache=app.state.cache,
                tracer=app.state.tracer,
            )
            app.state.rag_service = AgenticRAGService(ctx)
            logger.info("Agentic RAG service ready")
        else:
            app.state.rag_service = None
            logger.warning("Agentic RAG service skipped — missing backing services")
    except Exception as exc:
        logger.warning("Agentic RAG service failed: %s", exc)
        app.state.rag_service = None

    # --- Legacy RagBot service (backward-compatible /analyze) ---
    # NOTE(review): the /analyze/natural router also expects
    # app.state.extraction_service, which is never initialised here — confirm
    # the wiring, otherwise that endpoint always answers 503.
    try:
        from api.app.services.ragbot import get_ragbot_service

        ragbot = get_ragbot_service()
        ragbot.initialize()
        app.state.ragbot_service = ragbot
        logger.info("Legacy RagBot service ready")
    except Exception as exc:
        logger.warning("Legacy RagBot service unavailable: %s", exc)
        app.state.ragbot_service = None

    logger.info("All services initialised — ready to serve")
    logger.info("=" * 70)

    yield  # ---- server running ----

    logger.info("Shutting down MediGuard AI …")


# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------

def create_app() -> FastAPI:
    """Build and return the configured FastAPI application."""
    app = FastAPI(
        title="MediGuard AI",
        description="Production medical biomarker analysis — agentic RAG + multi-agent workflow",
        version=APP_VERSION,
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json",
    )

    # --- CORS ---
    # Strip whitespace so "http://a, http://b" style lists actually match;
    # an empty/blank variable falls back to the permissive wildcard.
    raw_origins = os.getenv("CORS_ALLOWED_ORIGINS", "*")
    origins = [o.strip() for o in raw_origins.split(",") if o.strip()] or ["*"]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        # Credentials cannot be combined with a wildcard origin.
        allow_credentials=origins != ["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # --- Exception handlers ---
    @app.exception_handler(RequestValidationError)
    async def validation_error(request: Request, exc: RequestValidationError):
        """Return a structured envelope for 422 validation failures."""
        return JSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content={
                "status": "error",
                "error_code": "VALIDATION_ERROR",
                "message": "Request validation failed",
                "details": exc.errors(),
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )

    @app.exception_handler(Exception)
    async def catch_all(request: Request, exc: Exception):
        """Last-resort handler: log the traceback, hide internals from clients."""
        logger.error("Unhandled exception: %s", exc, exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "status": "error",
                "error_code": "INTERNAL_SERVER_ERROR",
                "message": "An unexpected error occurred. Please try again later.",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )

    # --- Routers ---
    from src.routers import health, analyze, ask, search

    app.include_router(health.router)
    app.include_router(analyze.router)
    app.include_router(ask.router)
    app.include_router(search.router)

    @app.get("/")
    async def root():
        """Service discovery: name, version and endpoint map."""
        return {
            "name": "MediGuard AI",
            "version": APP_VERSION,
            "status": "online",
            "endpoints": {
                "health": "/health",
                "health_ready": "/health/ready",
                "analyze_natural": "/analyze/natural",
                "analyze_structured": "/analyze/structured",
                "ask": "/ask",
                "search": "/search",
                "docs": "/docs",
            },
        }

    return app


# Module-level app for ``uvicorn src.main:app``
app = create_app()
"""
MediGuard AI — Analysis repository (data-access layer).
"""

from __future__ import annotations

from typing import List, Optional

from sqlalchemy.orm import Session

from src.models.analysis import PatientAnalysis


class AnalysisRepository:
    """Persistence operations for :class:`PatientAnalysis` rows."""

    def __init__(self, db: Session):
        # The caller owns the session and its transaction boundaries.
        self.db = db

    def create(self, analysis: PatientAnalysis) -> PatientAnalysis:
        """Stage a new analysis and flush so DB-generated fields populate."""
        self.db.add(analysis)
        self.db.flush()
        return analysis

    def get_by_request_id(self, request_id: str) -> Optional[PatientAnalysis]:
        """Look up a single analysis by its external request identifier."""
        query = self.db.query(PatientAnalysis)
        return query.filter(PatientAnalysis.request_id == request_id).first()

    def list_recent(self, limit: int = 20) -> List[PatientAnalysis]:
        """Return the most recently created analyses, newest first."""
        query = self.db.query(PatientAnalysis)
        query = query.order_by(PatientAnalysis.created_at.desc())
        return query.limit(limit).all()

    def count(self) -> int:
        """Total number of stored analyses."""
        return self.db.query(PatientAnalysis).count()
+""" + +from __future__ import annotations + +from typing import List, Optional + +from sqlalchemy.orm import Session + +from src.models.analysis import MedicalDocument + + +class DocumentRepository: + """CRUD for ingested medical documents.""" + + def __init__(self, db: Session): + self.db = db + + def upsert(self, doc: MedicalDocument) -> MedicalDocument: + existing = ( + self.db.query(MedicalDocument) + .filter(MedicalDocument.content_hash == doc.content_hash) + .first() + ) + if existing: + existing.parse_status = doc.parse_status + existing.chunk_count = doc.chunk_count + existing.indexed_at = doc.indexed_at + self.db.flush() + return existing + self.db.add(doc) + self.db.flush() + return doc + + def get_by_id(self, doc_id: str) -> Optional[MedicalDocument]: + return self.db.query(MedicalDocument).filter(MedicalDocument.id == doc_id).first() + + def list_all(self, limit: int = 100) -> List[MedicalDocument]: + return ( + self.db.query(MedicalDocument) + .order_by(MedicalDocument.created_at.desc()) + .limit(limit) + .all() + ) + + def count(self) -> int: + return self.db.query(MedicalDocument).count() diff --git a/src/routers/__init__.py b/src/routers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fca7119d3e666ca87d1efac3e9e54341814db5ad --- /dev/null +++ b/src/routers/__init__.py @@ -0,0 +1 @@ +"""MediGuard AI — Production API routers.""" diff --git a/src/routers/analyze.py b/src/routers/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..573302b31bd368bf1a960b39884d4c0ddf504dd6 --- /dev/null +++ b/src/routers/analyze.py @@ -0,0 +1,88 @@ +""" +MediGuard AI — Analyze Router + +Backward-compatible /analyze/natural and /analyze/structured endpoints +that delegate to the existing ClinicalInsightGuild workflow. 
+""" + +from __future__ import annotations + +import logging +import time +import uuid +from datetime import datetime, timezone +from typing import Any, Dict + +from fastapi import APIRouter, HTTPException, Request + +from src.schemas.schemas import ( + AnalysisResponse, + ErrorResponse, + NaturalAnalysisRequest, + StructuredAnalysisRequest, +) + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/analyze", tags=["analysis"]) + + +async def _run_guild_analysis( + request: Request, + biomarkers: Dict[str, float], + patient_ctx: Dict[str, Any], + extracted_biomarkers: Dict[str, float] | None = None, +) -> AnalysisResponse: + """Execute the ClinicalInsightGuild and build the response envelope.""" + request_id = f"req_{uuid.uuid4().hex[:12]}" + t0 = time.time() + + ragbot = getattr(request.app.state, "ragbot_service", None) + if ragbot is None: + raise HTTPException(status_code=503, detail="Analysis service unavailable") + + try: + result = await ragbot.analyze(biomarkers, patient_ctx) + except Exception as exc: + logger.exception("Guild analysis failed: %s", exc) + raise HTTPException( + status_code=500, + detail=f"Analysis pipeline error: {exc}", + ) + + elapsed = (time.time() - t0) * 1000 + + # The guild returns a dict shaped like AnalysisResponse — pass through + return AnalysisResponse( + status="success", + request_id=request_id, + timestamp=datetime.now(timezone.utc).isoformat(), + extracted_biomarkers=extracted_biomarkers, + input_biomarkers=biomarkers, + patient_context=patient_ctx, + processing_time_ms=round(elapsed, 1), + **{k: v for k, v in result.items() if k not in ("status", "request_id", "timestamp", "extracted_biomarkers", "input_biomarkers", "patient_context", "processing_time_ms")}, + ) + + +@router.post("/natural", response_model=AnalysisResponse) +async def analyze_natural(body: NaturalAnalysisRequest, request: Request): + """Extract biomarkers from natural language and run full analysis.""" + extraction_svc = 
getattr(request.app.state, "extraction_service", None) + if extraction_svc is None: + raise HTTPException(status_code=503, detail="Extraction service unavailable") + + try: + extracted = await extraction_svc.extract_biomarkers(body.message) + except Exception as exc: + logger.exception("Biomarker extraction failed: %s", exc) + raise HTTPException(status_code=422, detail=f"Could not extract biomarkers: {exc}") + + patient_ctx = body.patient_context.model_dump(exclude_none=True) if body.patient_context else {} + return await _run_guild_analysis(request, extracted, patient_ctx, extracted_biomarkers=extracted) + + +@router.post("/structured", response_model=AnalysisResponse) +async def analyze_structured(body: StructuredAnalysisRequest, request: Request): + """Run full analysis on pre-structured biomarker data.""" + patient_ctx = body.patient_context.model_dump(exclude_none=True) if body.patient_context else {} + return await _run_guild_analysis(request, body.biomarkers, patient_ctx) diff --git a/src/routers/ask.py b/src/routers/ask.py new file mode 100644 index 0000000000000000000000000000000000000000..37b7897c9ee8e284704eb275cd40e738a9271e4e --- /dev/null +++ b/src/routers/ask.py @@ -0,0 +1,53 @@ +""" +MediGuard AI — Ask Router + +Free-form medical Q&A powered by the agentic RAG pipeline. 
+""" + +from __future__ import annotations + +import logging +import time +import uuid +from datetime import datetime, timezone + +from fastapi import APIRouter, HTTPException, Request + +from src.schemas.schemas import AskRequest, AskResponse + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["ask"]) + + +@router.post("/ask", response_model=AskResponse) +async def ask_medical_question(body: AskRequest, request: Request): + """Answer a free-form medical question via agentic RAG.""" + rag_service = getattr(request.app.state, "rag_service", None) + if rag_service is None: + raise HTTPException(status_code=503, detail="RAG service unavailable") + + request_id = f"req_{uuid.uuid4().hex[:12]}" + t0 = time.time() + + try: + result = rag_service.ask( + query=body.question, + biomarkers=body.biomarkers, + patient_context=body.patient_context or "", + ) + except Exception as exc: + logger.exception("Agentic RAG failed: %s", exc) + raise HTTPException(status_code=500, detail=f"RAG pipeline error: {exc}") + + elapsed = (time.time() - t0) * 1000 + + return AskResponse( + status="success", + request_id=request_id, + question=body.question, + answer=result.get("final_answer", ""), + guardrail_score=result.get("guardrail_score"), + documents_retrieved=len(result.get("retrieved_documents", [])), + documents_relevant=len(result.get("relevant_documents", [])), + processing_time_ms=round(elapsed, 1), + ) diff --git a/src/routers/health.py b/src/routers/health.py new file mode 100644 index 0000000000000000000000000000000000000000..8b17330d2c0b9c84faef4ba848811e6e8b2065bf --- /dev/null +++ b/src/routers/health.py @@ -0,0 +1,101 @@ +""" +MediGuard AI — Health Router + +Provides /health and /health/ready with per-service checks. 
+""" + +from __future__ import annotations + +import time +from datetime import datetime, timezone + +from fastapi import APIRouter, Request + +from src.schemas.schemas import HealthResponse, ServiceHealth + +router = APIRouter(tags=["health"]) + + +@router.get("/health", response_model=HealthResponse) +async def health_check(request: Request) -> HealthResponse: + """Shallow liveness probe.""" + app_state = request.app.state + uptime = time.time() - getattr(app_state, "start_time", time.time()) + return HealthResponse( + status="healthy", + timestamp=datetime.now(timezone.utc).isoformat(), + version=getattr(app_state, "version", "2.0.0"), + uptime_seconds=round(uptime, 2), + ) + + +@router.get("/health/ready", response_model=HealthResponse) +async def readiness_check(request: Request) -> HealthResponse: + """Deep readiness probe — checks all backing services.""" + app_state = request.app.state + uptime = time.time() - getattr(app_state, "start_time", time.time()) + services: list[ServiceHealth] = [] + overall = "healthy" + + # --- OpenSearch --- + try: + os_client = getattr(app_state, "opensearch_client", None) + if os_client is not None: + t0 = time.time() + info = os_client.health() + latency = (time.time() - t0) * 1000 + os_status = info.get("status", "unknown") + services.append(ServiceHealth(name="opensearch", status="ok" if os_status in ("green", "yellow") else "degraded", latency_ms=round(latency, 1))) + else: + services.append(ServiceHealth(name="opensearch", status="unavailable")) + except Exception as exc: + services.append(ServiceHealth(name="opensearch", status="unavailable", detail=str(exc))) + overall = "degraded" + + # --- Redis --- + try: + cache = getattr(app_state, "cache", None) + if cache is not None: + t0 = time.time() + cache.set("__health__", "ok", ttl=10) + latency = (time.time() - t0) * 1000 + services.append(ServiceHealth(name="redis", status="ok", latency_ms=round(latency, 1))) + else: + services.append(ServiceHealth(name="redis", 
status="unavailable")) + except Exception as exc: + services.append(ServiceHealth(name="redis", status="unavailable", detail=str(exc))) + + # --- Ollama --- + try: + ollama = getattr(app_state, "ollama_client", None) + if ollama is not None: + t0 = time.time() + healthy = ollama.health() + latency = (time.time() - t0) * 1000 + services.append(ServiceHealth(name="ollama", status="ok" if healthy else "degraded", latency_ms=round(latency, 1))) + else: + services.append(ServiceHealth(name="ollama", status="unavailable")) + except Exception as exc: + services.append(ServiceHealth(name="ollama", status="unavailable", detail=str(exc))) + overall = "degraded" + + # --- Langfuse --- + try: + tracer = getattr(app_state, "tracer", None) + if tracer is not None: + services.append(ServiceHealth(name="langfuse", status="ok")) + else: + services.append(ServiceHealth(name="langfuse", status="unavailable")) + except Exception as exc: + services.append(ServiceHealth(name="langfuse", status="unavailable", detail=str(exc))) + + if any(s.status == "unavailable" for s in services if s.name in ("opensearch", "ollama")): + overall = "unhealthy" + + return HealthResponse( + status=overall, + timestamp=datetime.now(timezone.utc).isoformat(), + version=getattr(app_state, "version", "2.0.0"), + uptime_seconds=round(uptime, 2), + services=services, + ) diff --git a/src/routers/search.py b/src/routers/search.py new file mode 100644 index 0000000000000000000000000000000000000000..d3edab9a9b3068a66b781dcad3ad9b3f14257317 --- /dev/null +++ b/src/routers/search.py @@ -0,0 +1,72 @@ +""" +MediGuard AI — Search Router + +Direct hybrid search endpoint (no LLM generation). 
+""" + +from __future__ import annotations + +import logging +import time + +from fastapi import APIRouter, HTTPException, Request + +from src.schemas.schemas import SearchRequest, SearchResponse + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["search"]) + + +@router.post("/search", response_model=SearchResponse) +async def hybrid_search(body: SearchRequest, request: Request): + """Execute a direct hybrid search against the OpenSearch index.""" + os_client = getattr(request.app.state, "opensearch_client", None) + embedding_service = getattr(request.app.state, "embedding_service", None) + + if os_client is None: + raise HTTPException(status_code=503, detail="Search service unavailable") + + t0 = time.time() + + try: + if body.mode == "bm25": + results = os_client.search_bm25(query_text=body.query, top_k=body.top_k) + elif body.mode == "vector": + if embedding_service is None: + raise HTTPException(status_code=503, detail="Embedding service unavailable for vector search") + vec = embedding_service.embed_query(body.query) + results = os_client.search_vector(query_vector=vec, top_k=body.top_k) + else: + # hybrid + if embedding_service is None: + logger.warning("Embedding service unavailable — falling back to BM25") + results = os_client.search_bm25(query_text=body.query, top_k=body.top_k) + else: + vec = embedding_service.embed_query(body.query) + results = os_client.search_hybrid(query_text=body.query, query_vector=vec, top_k=body.top_k) + except HTTPException: + raise + except Exception as exc: + logger.exception("Search failed: %s", exc) + raise HTTPException(status_code=500, detail=f"Search error: {exc}") + + elapsed = (time.time() - t0) * 1000 + + formatted = [ + { + "id": hit.get("_id", ""), + "score": hit.get("_score", 0.0), + "title": hit.get("_source", {}).get("title", ""), + "section": hit.get("_source", {}).get("section_title", ""), + "text": hit.get("_source", {}).get("chunk_text", "")[:500], + } + for hit in results + ] + + return 
SearchResponse( + query=body.query, + mode=body.mode, + total_hits=len(formatted), + results=formatted, + processing_time_ms=round(elapsed, 1), + ) diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeeb3dd500a81b1183465f70a77f7943b4919a9e --- /dev/null +++ b/src/schemas/__init__.py @@ -0,0 +1 @@ +"""MediGuard AI — API request/response schemas.""" diff --git a/src/schemas/schemas.py b/src/schemas/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..50bfe95d55ef9592e4e79577388015c027b00a9d --- /dev/null +++ b/src/schemas/schemas.py @@ -0,0 +1,247 @@ +""" +MediGuard AI — Production API Schemas + +Pydantic v2 request/response models for the new production API layer. +Keeps backward compatibility with existing schemas where possible. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +# ============================================================================ +# REQUEST MODELS +# ============================================================================ + + +class PatientContext(BaseModel): + """Patient demographic and context information.""" + + age: Optional[int] = Field(None, ge=0, le=120, description="Patient age in years") + gender: Optional[str] = Field(None, description="Patient gender (male/female)") + bmi: Optional[float] = Field(None, ge=10, le=60, description="Body Mass Index") + patient_id: Optional[str] = Field(None, description="Patient identifier") + + +class NaturalAnalysisRequest(BaseModel): + """Natural language biomarker analysis request.""" + + message: str = Field( + ..., min_length=5, max_length=2000, + description="Natural language message with biomarker values", + ) + patient_context: Optional[PatientContext] = Field( + default_factory=PatientContext, + ) + + +class 
class StructuredAnalysisRequest(BaseModel):
    """Structured biomarker analysis request."""

    biomarkers: Dict[str, float] = Field(
        ...,
        description="Dict of biomarker name → measured value",
    )
    patient_context: Optional[PatientContext] = Field(
        default_factory=PatientContext,
    )

    @field_validator("biomarkers")
    @classmethod
    def biomarkers_not_empty(cls, v: Dict[str, float]) -> Dict[str, float]:
        """Reject an empty biomarker mapping up front."""
        if not v:
            raise ValueError("biomarkers must contain at least one entry")
        return v


class AskRequest(BaseModel):
    """Free-form medical question (agentic RAG pipeline)."""

    question: str = Field(
        ...,
        min_length=3,
        max_length=4000,
        description="Medical question",
    )
    biomarkers: Optional[Dict[str, float]] = Field(
        None, description="Optional biomarker context",
    )
    patient_context: Optional[str] = Field(
        None, description="Free‑text patient context",
    )


class SearchRequest(BaseModel):
    """Direct hybrid search (no LLM generation)."""

    query: str = Field(..., min_length=2, max_length=1000)
    top_k: int = Field(10, ge=1, le=100)
    mode: str = Field("hybrid", description="Search mode: bm25 | vector | hybrid")


# ============================================================================
# RESPONSE BUILDING BLOCKS
# ============================================================================


class BiomarkerFlag(BaseModel):
    """Single biomarker reading flagged against its reference range."""

    name: str
    value: float
    unit: str
    status: str
    reference_range: str
    warning: Optional[str] = None


class SafetyAlert(BaseModel):
    """Actionable safety alert raised by the analysis pipeline."""

    severity: str
    biomarker: Optional[str] = None
    message: str
    action: str


class KeyDriver(BaseModel):
    """A biomarker identified as driving the prediction."""

    biomarker: str
    value: Any
    contribution: Optional[str] = None
    explanation: str
    evidence: Optional[str] = None


class Prediction(BaseModel):
    """Model prediction with per-class probabilities."""

    disease: str
    confidence: float = Field(ge=0, le=1)
    probabilities: Dict[str, float]


class DiseaseExplanation(BaseModel):
    """Evidence-grounded explanation of the predicted condition."""

    pathophysiology: str
    citations: List[str] = Field(default_factory=list)
    retrieved_chunks: Optional[List[Dict[str, Any]]] = None


class Recommendations(BaseModel):
    """Clinical recommendations grouped by time horizon."""

    immediate_actions: List[str] = Field(default_factory=list)
    lifestyle_changes: List[str] = Field(default_factory=list)
    monitoring: List[str] = Field(default_factory=list)
    follow_up: Optional[str] = None


class ConfidenceAssessment(BaseModel):
    """How much the prediction and evidence should be trusted."""

    prediction_reliability: str
    evidence_strength: str
    limitations: List[str] = Field(default_factory=list)
    reasoning: Optional[str] = None


class AgentOutput(BaseModel):
    """Raw output of one agent in the multi-agent workflow."""

    agent_name: str
    findings: Any
    metadata: Optional[Dict[str, Any]] = None
    execution_time_ms: Optional[float] = None


class Analysis(BaseModel):
    """Aggregated clinical analysis produced by the guild workflow."""

    biomarker_flags: List[BiomarkerFlag]
    safety_alerts: List[SafetyAlert]
    key_drivers: List[KeyDriver]
    disease_explanation: DiseaseExplanation
    recommendations: Recommendations
    confidence_assessment: ConfidenceAssessment
    alternative_diagnoses: Optional[List[Dict[str, Any]]] = None


# ============================================================================
# TOP-LEVEL RESPONSES
# ============================================================================


class AnalysisResponse(BaseModel):
    """Full clinical analysis response (backward-compatible)."""

    status: str
    request_id: str
    timestamp: str
    extracted_biomarkers: Optional[Dict[str, float]] = None
    input_biomarkers: Dict[str, float]
    patient_context: Dict[str, Any]
    prediction: Prediction
    analysis: Analysis
    agent_outputs: List[AgentOutput]
    workflow_metadata: Dict[str, Any]
    conversational_summary: Optional[str] = None
    processing_time_ms: float
    sop_version: Optional[str] = None


class AskResponse(BaseModel):
    """Response from the agentic RAG /ask endpoint."""

    status: str = "success"
    request_id: str
    question: str
    answer: str
    guardrail_score: Optional[float] = None
    documents_retrieved: int = 0
    documents_relevant: int = 0
    processing_time_ms: float = 0.0


class SearchResponse(BaseModel):
    """Direct hybrid search response."""

    status: str = "success"
    query: str
    mode: str
    total_hits: int
    results: List[Dict[str, Any]]
    processing_time_ms: float = 0.0


class ErrorResponse(BaseModel):
    """Error envelope."""

    status: str = "error"
    error_code: str
    message: str
    details: Optional[Dict[str, Any]] = None
    timestamp: str
    request_id: Optional[str] = None


# ============================================================================
# HEALTH / INFO
# ============================================================================


class ServiceHealth(BaseModel):
    """Status of one backing service: ok | degraded | unavailable."""

    name: str
    status: str
    latency_ms: Optional[float] = None
    detail: Optional[str] = None


class HealthResponse(BaseModel):
    """Production health check: healthy | degraded | unhealthy."""

    status: str
    timestamp: str
    version: str
    uptime_seconds: float
    services: List[ServiceHealth] = Field(default_factory=list)


class BiomarkerReferenceRange(BaseModel):
    """Reference interval, optionally split by sex."""

    min: Optional[float] = None
    max: Optional[float] = None
    male: Optional[Dict[str, float]] = None
    female: Optional[Dict[str, float]] = None


class BiomarkerInfo(BaseModel):
    """Static metadata describing one biomarker."""

    name: str
    unit: str
    normal_range: BiomarkerReferenceRange
    critical_low: Optional[float] = None
    critical_high: Optional[float] = None
Orchestrator + +LangGraph StateGraph that wires all nodes into the guardrail → retrieve → grade → generate pipeline. +""" + +from __future__ import annotations + +import logging +from functools import lru_cache, partial +from typing import Any + +from langgraph.graph import END, StateGraph + +from src.services.agents.context import AgenticContext +from src.services.agents.nodes.generate_answer_node import generate_answer_node +from src.services.agents.nodes.grade_documents_node import grade_documents_node +from src.services.agents.nodes.guardrail_node import guardrail_node +from src.services.agents.nodes.out_of_scope_node import out_of_scope_node +from src.services.agents.nodes.retrieve_node import retrieve_node +from src.services.agents.nodes.rewrite_query_node import rewrite_query_node +from src.services.agents.state import AgenticRAGState + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Edge routing helpers +# --------------------------------------------------------------------------- + + +def _route_after_guardrail(state: dict) -> str: + """Decide path after guardrail evaluation.""" + if state.get("routing_decision") == "analyze": + # Biomarker analysis pathway — goes straight to retrieve + return "retrieve" + if state.get("is_in_scope"): + return "retrieve" + return "out_of_scope" + + +def _route_after_grading(state: dict) -> str: + """Decide whether to rewrite query or proceed to generation.""" + if state.get("needs_rewrite"): + return "rewrite_query" + if not state.get("relevant_documents"): + return "generate_answer" # will produce a "no evidence found" answer + return "generate_answer" + + +# --------------------------------------------------------------------------- +# Graph builder +# --------------------------------------------------------------------------- + + +def build_agentic_rag_graph(context: AgenticContext) -> Any: + """Construct the compiled LangGraph for the agentic RAG 
def build_agentic_rag_graph(context: AgenticContext) -> Any:
    """Construct the compiled LangGraph for the agentic RAG pipeline.

    Parameters
    ----------
    context:
        Runtime dependencies (LLM, OpenSearch, embeddings, cache, tracer).

    Returns
    -------
    Compiled LangGraph graph ready for ``.invoke()`` / ``.stream()``.
    """
    workflow = StateGraph(AgenticRAGState)

    # Bind the runtime context into every node with functools.partial so
    # the node functions stay free of module-level globals.
    node_fns = {
        "guardrail": guardrail_node,
        "retrieve": retrieve_node,
        "grade_documents": grade_documents_node,
        "rewrite_query": rewrite_query_node,
        "generate_answer": generate_answer_node,
        "out_of_scope": out_of_scope_node,
    }
    for node_name, fn in node_fns.items():
        workflow.add_node(node_name, partial(fn, context=context))

    # Entry point
    workflow.set_entry_point("guardrail")

    # Conditional edges
    workflow.add_conditional_edges(
        "guardrail",
        _route_after_guardrail,
        {"retrieve": "retrieve", "out_of_scope": "out_of_scope"},
    )

    workflow.add_edge("retrieve", "grade_documents")

    workflow.add_conditional_edges(
        "grade_documents",
        _route_after_grading,
        {"rewrite_query": "rewrite_query", "generate_answer": "generate_answer"},
    )

    # A rewritten query loops back into retrieval.
    workflow.add_edge("rewrite_query", "retrieve")

    # Terminal edges
    workflow.add_edge("generate_answer", END)
    workflow.add_edge("out_of_scope", END)

    return workflow.compile()


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


class AgenticRAGService:
    """High-level wrapper around the compiled RAG graph."""

    def __init__(self, context: AgenticContext) -> None:
        self._context = context
        self._graph = build_agentic_rag_graph(context)

    def ask(
        self,
        query: str,
        biomarkers: dict | None = None,
        patient_context: str = "",
    ) -> dict:
        """Run the full agentic RAG pipeline and return the final state.

        On any pipeline failure the initial state is returned with a safe
        fallback answer and the error recorded under ``errors``.
        """
        initial_state: dict[str, Any] = {
            "query": query,
            "biomarkers": biomarkers,
            "patient_context": patient_context,
            "errors": [],
        }

        tracer = self._context.tracer
        span = None
        try:
            if tracer:
                span = tracer.start_span(
                    name="agentic_rag_ask",
                    metadata={"query": query},
                )
            return self._graph.invoke(initial_state)
        except Exception as exc:
            logger.error("Agentic RAG pipeline failed: %s", exc)
            fallback = (
                "I apologize, but I'm temporarily unable to process your request. "
                "Please consult a healthcare professional."
            )
            return {**initial_state, "final_answer": fallback, "errors": [str(exc)]}
        finally:
            # Always close the span, even when the graph raised.
            if span is not None:
                tracer.end_span(span)
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass(frozen=True) +class AgenticContext: + """Immutable runtime context for agentic RAG nodes.""" + + llm: Any # LangChain chat model + embedding_service: Any # EmbeddingService + opensearch_client: Any # OpenSearchClient + cache: Any # RedisCache + tracer: Any # LangfuseTracer + guild: Optional[Any] = None # ClinicalInsightGuild (original workflow) diff --git a/src/services/agents/medical/__init__.py b/src/services/agents/medical/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e689127b711ba6347f08dc61aa9be4ea3681bed --- /dev/null +++ b/src/services/agents/medical/__init__.py @@ -0,0 +1 @@ +"""MediGuard AI — Medical agents (original 6 agents, re-exported).""" diff --git a/src/services/agents/nodes/__init__.py b/src/services/agents/nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa8035a8f90dea1d01bc65bb1220356b0eb9b097 --- /dev/null +++ b/src/services/agents/nodes/__init__.py @@ -0,0 +1 @@ +"""MediGuard AI — Agentic RAG nodes package.""" diff --git a/src/services/agents/nodes/generate_answer_node.py b/src/services/agents/nodes/generate_answer_node.py new file mode 100644 index 0000000000000000000000000000000000000000..ebda354daed31221ed87b1b7f1da5b2c11ab5276 --- /dev/null +++ b/src/services/agents/nodes/generate_answer_node.py @@ -0,0 +1,60 @@ +""" +MediGuard AI — Generate Answer Node + +Produces a RAG-grounded medical answer with citations. 
+""" + +from __future__ import annotations + +import logging +from typing import Any + +from src.services.agents.prompts import RAG_GENERATION_SYSTEM + +logger = logging.getLogger(__name__) + + +def generate_answer_node(state: dict, *, context: Any) -> dict: + """Generate a cited medical answer from relevant documents.""" + query = state.get("rewritten_query") or state.get("query", "") + documents = state.get("relevant_documents", []) + biomarkers = state.get("biomarkers") + patient_context = state.get("patient_context", "") + + # Build evidence block + evidence_parts: list[str] = [] + for i, doc in enumerate(documents, 1): + title = doc.get("title", "Unknown") + section = doc.get("section", "") + text = doc.get("text", "")[:2000] + header = f"[{i}] {title}" + if section: + header += f" — {section}" + evidence_parts.append(f"{header}\n{text}") + evidence_block = "\n\n---\n\n".join(evidence_parts) if evidence_parts else "(No evidence retrieved)" + + # Build user message + user_msg = f"Question: {query}\n\n" + if biomarkers: + user_msg += f"Biomarkers: {biomarkers}\n\n" + if patient_context: + user_msg += f"Patient context: {patient_context}\n\n" + user_msg += f"Evidence:\n{evidence_block}" + + try: + response = context.llm.invoke( + [ + {"role": "system", "content": RAG_GENERATION_SYSTEM}, + {"role": "user", "content": user_msg}, + ] + ) + answer = response.content.strip() + except Exception as exc: + logger.error("Generation LLM failed: %s", exc) + answer = ( + "I apologize, but I'm temporarily unable to generate a response. " + "Please consult a healthcare professional for guidance." 
logger = logging.getLogger(__name__)


def grade_documents_node(state: dict, *, context: Any) -> dict:
    """Grade each retrieved document for relevance.

    Asks the LLM for a JSON relevance verdict per document; an LLM failure
    marks that document relevant (benefit of the doubt). ``needs_rewrite``
    is set when fewer than two relevant documents survive and no rewrite
    has been attempted yet, which drives the rewrite→retrieve loop.
    """
    question = state.get("rewritten_query") or state.get("query", "")
    docs = state.get("retrieved_documents", [])

    if not docs:
        # Nothing came back at all — force a query rewrite.
        return {
            "grading_results": [],
            "relevant_documents": [],
            "needs_rewrite": True,
        }

    def _judge(doc: dict) -> bool:
        # Single-document relevance verdict; defaults to True on failure.
        prompt = f"Query: {question}\n\nDocument:\n{doc.get('text', '')[:2000]}"
        try:
            reply = context.llm.invoke(
                [
                    {"role": "system", "content": GRADING_SYSTEM},
                    {"role": "user", "content": prompt},
                ]
            )
            body = reply.content.strip()
            # Strip markdown code fences and a leading "json" tag.
            if "```" in body:
                body = body.split("```")[1].split("```")[0].strip()
            if body.startswith("json"):
                body = body[4:].strip()
            verdict = json.loads(body)
            return str(verdict.get("relevant", "false")).lower() == "true"
        except Exception as exc:
            logger.warning("Grading LLM failed for doc %s: %s — marking relevant", doc.get("id"), exc)
            return True  # benefit of the doubt

    grading_results: list[dict] = []
    relevant: list[dict] = []
    for doc in docs:
        is_relevant = _judge(doc)
        grading_results.append({"doc_id": doc.get("id"), "relevant": is_relevant})
        if is_relevant:
            relevant.append(doc)

    return {
        "grading_results": grading_results,
        "relevant_documents": relevant,
        "needs_rewrite": len(relevant) < 2 and not state.get("rewritten_query"),
    }
logger = logging.getLogger(__name__)


def guardrail_node(state: dict, *, context: Any) -> dict:
    """Score the query for medical relevance (0-100).

    Queries with biomarker payloads bypass the LLM entirely; otherwise the
    LLM returns a JSON score, with a permissive default of 70 when the LLM
    call or JSON parsing fails. Scores >= 40 are considered in scope.
    """
    query = state.get("query", "")

    # Fast path: a biomarker payload is medical by definition.
    if state.get("biomarkers"):
        return {
            "guardrail_score": 95.0,
            "is_in_scope": True,
            "routing_decision": "analyze",
        }

    try:
        reply = context.llm.invoke(
            [
                {"role": "system", "content": GUARDRAIL_SYSTEM},
                {"role": "user", "content": query},
            ]
        )
        body = reply.content.strip()
        # Strip markdown code fences and a leading "json" tag before parsing.
        if "```" in body:
            body = body.split("```")[1].split("```")[0].strip()
        if body.startswith("json"):
            body = body[4:].strip()
        score = float(json.loads(body).get("score", 0))
    except Exception as exc:
        logger.warning("Guardrail LLM failed: %s — defaulting to in-scope", exc)
        score = 70.0  # benefit of the doubt

    in_scope = score >= 40
    return {
        "guardrail_score": score,
        "is_in_scope": in_scope,
        "routing_decision": "rag_answer" if in_scope else "out_of_scope",
    }
logger = logging.getLogger(__name__)


def retrieve_node(state: dict, *, context: Any) -> dict:
    """Retrieve documents from OpenSearch via hybrid search.

    Pipeline: cache lookup → query embedding → hybrid (BM25 + vector KNN)
    search with a BM25-only fallback → result normalisation → cache write.

    Returns a partial state update containing ``retrieved_documents``
    (and ``errors`` when embedding or both search paths fail).
    """
    query = state.get("rewritten_query") or state.get("query", "")

    # 1. Try cache first
    cache_key = f"retrieve:{query}"
    if context.cache:
        cached = context.cache.get(cache_key)
        if cached is not None:
            logger.debug("Cache hit for retrieve query")
            return {"retrieved_documents": cached}

    # 2. Embed the query
    try:
        query_embedding = context.embedding_service.embed_query(query)
    except Exception as exc:
        logger.error("Embedding failed: %s", exc)
        return {"retrieved_documents": [], "errors": [str(exc)]}

    # 3. Hybrid search, degrading to pure BM25 when the hybrid path fails
    try:
        results = context.opensearch_client.search_hybrid(
            query_text=query,
            query_vector=query_embedding,
            top_k=10,
        )
    except Exception as exc:
        logger.error("OpenSearch hybrid search failed: %s — falling back to BM25", exc)
        try:
            results = context.opensearch_client.search_bm25(
                query_text=query,
                top_k=10,
            )
        except Exception as exc2:
            logger.error("BM25 fallback also failed: %s", exc2)
            return {"retrieved_documents": [], "errors": [str(exc), str(exc2)]}

    # Normalise raw OpenSearch hits into flat document dicts.
    documents = [
        {
            "id": hit.get("_id", ""),
            "score": hit.get("_score", 0.0),
            "text": hit.get("_source", {}).get("chunk_text", ""),
            "title": hit.get("_source", {}).get("title", ""),
            "section": hit.get("_source", {}).get("section_title", ""),
            "metadata": hit.get("_source", {}),
        }
        for hit in results
    ]

    # 4. Store in cache (5 min TTL).
    # BUG FIX: RedisCache.set() takes the *value* first and the key parts as
    # varargs — set(value, *key_parts, ttl=...). The previous call passed
    # (cache_key, documents), so the document list landed in key_parts and
    # _make_key()'s "|".join raised TypeError after every uncached retrieval.
    if context.cache:
        context.cache.set(documents, cache_key, ttl=300)

    return {"retrieved_documents": documents}
+""" + +from __future__ import annotations + +import logging +from typing import Any + +from src.services.agents.prompts import REWRITE_SYSTEM + +logger = logging.getLogger(__name__) + + +def rewrite_query_node(state: dict, *, context: Any) -> dict: + """Rewrite the original query for better retrieval.""" + original = state.get("query", "") + patient_context = state.get("patient_context", "") + + user_msg = f"Original query: {original}" + if patient_context: + user_msg += f"\n\nPatient context: {patient_context}" + + try: + response = context.llm.invoke( + [ + {"role": "system", "content": REWRITE_SYSTEM}, + {"role": "user", "content": user_msg}, + ] + ) + rewritten = response.content.strip() + if not rewritten: + rewritten = original + except Exception as exc: + logger.warning("Rewrite LLM failed: %s — keeping original query", exc) + rewritten = original + + return {"rewritten_query": rewritten} diff --git a/src/services/agents/prompts.py b/src/services/agents/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6fd24e47f9a97d023543a0ce01553943bf59d0 --- /dev/null +++ b/src/services/agents/prompts.py @@ -0,0 +1,72 @@ +""" +MediGuard AI — Agentic RAG Prompts + +Medical-domain prompts for guardrail, grading, rewriting, and generation. +""" + +# ── Guardrail prompt ───────────────────────────────────────────────────────── + +GUARDRAIL_SYSTEM = """\ +You are a medical-domain classifier. Determine whether the user query is +about health, biomarkers, medical conditions, clinical guidelines, or +wellness — topics that MediGuard AI can help with. 
+ +Score the query from 0 to 100: + 90-100 Clearly medical (biomarker values, disease questions, symptoms) + 60-89 Health-adjacent (nutrition, fitness, wellness) + 30-59 Loosely related (general biology, anatomy trivia) + 0-29 Not medical at all (weather, coding, sports) + +Respond ONLY with JSON: +{{"score": , "reason": ""}} +""" + +# ── Document grading prompt ────────────────────────────────────────────────── + +GRADING_SYSTEM = """\ +You are a medical-relevance grader. Given a user question and a retrieved +document chunk, decide whether the document is relevant to answering the +medical question. + +Respond ONLY with JSON: +{{"relevant": true/false, "reason": ""}} +""" + +# ── Query rewriting prompt ─────────────────────────────────────────────────── + +REWRITE_SYSTEM = """\ +You are a medical-query optimiser. The original user query did not +retrieve relevant medical documents. Rewrite it to improve retrieval from +a medical knowledge base. + +Guidelines: +- Use standard medical terminology +- Add synonyms for biomarker names +- Make the intent clearer + +Respond with ONLY the rewritten query (no explanation, no quotes). +""" + +# ── RAG generation prompt ──────────────────────────────────────────────────── + +RAG_GENERATION_SYSTEM = """\ +You are MediGuard AI, a clinical-information assistant. +Answer the user's medical question using ONLY the provided context documents. +If the context is insufficient, say so honestly. + +Rules: +1. Cite specific documents with [Source: filename, Page X]. +2. Use patient-friendly language. +3. Never provide a definitive diagnosis — use "may indicate", "suggests". +4. Always end with: "Please consult a healthcare professional for diagnosis." +5. If biomarker values are critical, highlight them as safety alerts. +""" + +# ── Out-of-scope response ─────────────────────────────────────────────────── + +OUT_OF_SCOPE_RESPONSE = ( + "I'm MediGuard AI — I specialise in medical biomarker analysis and " + "health-related questions. 
class AgenticRAGState(TypedDict):
    """State flowing through the agentic RAG graph.

    Each node returns a partial update that LangGraph merges into this
    state. ``errors`` carries an ``operator.add`` reducer so error lists
    from different nodes are concatenated rather than overwritten; every
    other key is last-writer-wins.
    """

    # ── Input ──────────────────────────────────────────────────────────── 
    query: str
    biomarkers: Optional[Dict[str, float]]
    # NOTE(review): annotated as a dict here, but AgenticRAGService.ask
    # passes a plain str and the nodes format it as text — confirm which
    # type is intended.
    patient_context: Optional[Dict[str, Any]]

    # ── Guardrail ──────────────────────────────────────────────────────── 
    guardrail_score: float  # 0-100 medical-relevance score
    is_in_scope: bool  # passed guardrail?

    # ── Retrieval ──────────────────────────────────────────────────────── 
    retrieved_documents: List[Dict[str, Any]]
    # NOTE(review): declared but not updated by the visible nodes — verify
    # the retry-limit bookkeeping is wired up somewhere.
    retrieval_attempts: int
    max_retrieval_attempts: int

    # ── Grading ────────────────────────────────────────────────────────── 
    grading_results: List[Dict[str, Any]]  # one {"doc_id", "relevant"} per doc
    relevant_documents: List[Dict[str, Any]]  # subset of retrieved_documents
    needs_rewrite: bool  # triggers the rewrite→retrieve loop

    # ── Rewriting ──────────────────────────────────────────────────────── 
    rewritten_query: Optional[str]  # set at most once; also loop guard

    # ── Generation / routing ───────────────────────────────────────────── 
    routing_decision: str  # "analyze" | "rag_answer" | "out_of_scope"
    final_answer: Optional[str]
    analysis_result: Optional[Dict[str, Any]]

    # ── Metadata ───────────────────────────────────────────────────────── 
    trace_id: Optional[str]
    errors: Annotated[List[str], operator.add]  # accumulated across nodes
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Any, Dict, List, Optional + +from src.biomarker_validator import BiomarkerValidator +from src.biomarker_normalization import normalize_biomarker_name +from src.settings import get_settings + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BiomarkerResult: + """Validated result for a single biomarker.""" + + name: str + value: float + unit: str + status: str # NORMAL | HIGH | LOW | CRITICAL_HIGH | CRITICAL_LOW + reference_range: str + warning: Optional[str] = None + + +@dataclass +class ValidationReport: + """Complete biomarker validation report.""" + + results: List[BiomarkerResult] = field(default_factory=list) + safety_alerts: List[Dict[str, Any]] = field(default_factory=list) + recognized_count: int = 0 + unrecognized: List[str] = field(default_factory=list) + + +class BiomarkerService: + """Production biomarker validation service.""" + + def __init__(self) -> None: + self._validator = BiomarkerValidator() + + # --------------------------------------------------------------------- # + # Public API + # --------------------------------------------------------------------- # + + def validate( + self, + biomarkers: Dict[str, float], + gender: Optional[str] = None, + ) -> ValidationReport: + """Validate a dict of biomarker name → value and return a report.""" + report = ValidationReport() + + for raw_name, value in biomarkers.items(): + normalized = normalize_biomarker_name(raw_name) + flag = self._validator.validate_biomarker(normalized, value, gender=gender) + if flag is None: + report.unrecognized.append(raw_name) + continue + if flag.status == "UNKNOWN": + report.unrecognized.append(raw_name) + continue + report.recognized_count += 1 + report.results.append( + BiomarkerResult( + name=flag.name, + value=flag.value, + unit=flag.unit, + status=flag.status, + 
reference_range=flag.reference_range, + warning=flag.warning, + ) + ) + if flag.status.startswith("CRITICAL"): + report.safety_alerts.append( + { + "severity": "CRITICAL", + "biomarker": normalized, + "message": flag.warning or f"{normalized} is critically out of range", + "action": "Seek immediate medical attention", + } + ) + + return report + + def list_supported(self) -> List[Dict[str, Any]]: + """Return metadata for all supported biomarkers.""" + result = [] + for name, ref in self._validator.references.items(): + result.append({ + "name": name, + "unit": ref.get("unit", ""), + "normal_range": ref.get("normal_range", {}), + "critical_low": ref.get("critical_low"), + "critical_high": ref.get("critical_high"), + }) + return result + + +@lru_cache(maxsize=1) +def make_biomarker_service() -> BiomarkerService: + return BiomarkerService() diff --git a/src/services/cache/__init__.py b/src/services/cache/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f3ff8596b70e870650f00263ee8e511da34320 --- /dev/null +++ b/src/services/cache/__init__.py @@ -0,0 +1,4 @@ +"""MediGuard AI — Redis cache service package.""" +from src.services.cache.redis_cache import RedisCache, make_redis_cache + +__all__ = ["RedisCache", "make_redis_cache"] diff --git a/src/services/cache/redis_cache.py b/src/services/cache/redis_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d936062cd88fa2cb39ed70730e9c800022b99f --- /dev/null +++ b/src/services/cache/redis_cache.py @@ -0,0 +1,123 @@ +""" +MediGuard AI — Redis Cache + +Exact-match caching with SHA-256 keys for RAG and analysis responses. +Gracefully degrades when Redis is unavailable. 
+""" + +from __future__ import annotations + +import hashlib +import json +import logging +from functools import lru_cache +from typing import Any, Dict, Optional + +from src.settings import get_settings + +logger = logging.getLogger(__name__) + +try: + import redis as _redis +except ImportError: # pragma: no cover + _redis = None # type: ignore[assignment] + + +class RedisCache: + """Thin Redis wrapper with SHA-256 key generation and JSON ser/de.""" + + def __init__(self, client: Any, default_ttl: int = 21600): + self._client = client + self._default_ttl = default_ttl + self._enabled = client is not None + + @property + def enabled(self) -> bool: + return self._enabled + + def ping(self) -> bool: + if not self._enabled: + return False + try: + return self._client.ping() + except Exception: + return False + + @staticmethod + def _make_key(*parts: str) -> str: + raw = "|".join(parts) + return f"mediguard:{hashlib.sha256(raw.encode()).hexdigest()}" + + def get(self, *key_parts: str) -> Optional[Dict[str, Any]]: + if not self._enabled: + return None + key = self._make_key(*key_parts) + try: + value = self._client.get(key) + if value is None: + return None + return json.loads(value) + except Exception as exc: + logger.warning("Cache GET failed: %s", exc) + return None + + def set(self, value: Dict[str, Any], *key_parts: str, ttl: Optional[int] = None) -> bool: + if not self._enabled: + return False + key = self._make_key(*key_parts) + try: + self._client.setex(key, ttl or self._default_ttl, json.dumps(value, default=str)) + return True + except Exception as exc: + logger.warning("Cache SET failed: %s", exc) + return False + + def delete(self, *key_parts: str) -> bool: + if not self._enabled: + return False + key = self._make_key(*key_parts) + try: + self._client.delete(key) + return True + except Exception as exc: + logger.warning("Cache DELETE failed: %s", exc) + return False + + def flush(self) -> bool: + if not self._enabled: + return False + try: + 
self._client.flushdb() + return True + except Exception: + return False + + +class _NullCache(RedisCache): + """No-op cache returned when Redis is disabled or unavailable.""" + + def __init__(self): + super().__init__(client=None) + + +@lru_cache(maxsize=1) +def make_redis_cache() -> RedisCache: + """Factory — returns a live cache or a silent null-cache.""" + settings = get_settings() + if not settings.redis.enabled or _redis is None: + logger.info("Redis caching disabled") + return _NullCache() + try: + client = _redis.Redis( + host=settings.redis.host, + port=settings.redis.port, + db=settings.redis.db, + decode_responses=True, + socket_connect_timeout=3, + ) + client.ping() + logger.info("Redis connected (%s:%d)", settings.redis.host, settings.redis.port) + return RedisCache(client, settings.redis.ttl_seconds) + except Exception as exc: + logger.warning("Redis unavailable (%s), running without cache", exc) + return _NullCache() diff --git a/src/services/embeddings/__init__.py b/src/services/embeddings/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a90f1ee3fbdc37f5fbf4fdfbc9865123bcb05437 --- /dev/null +++ b/src/services/embeddings/__init__.py @@ -0,0 +1,4 @@ +"""MediGuard AI — Embeddings service package.""" +from src.services.embeddings.service import EmbeddingService, make_embedding_service + +__all__ = ["EmbeddingService", "make_embedding_service"] diff --git a/src/services/embeddings/service.py b/src/services/embeddings/service.py new file mode 100644 index 0000000000000000000000000000000000000000..13c5ea8f3efe20c387099b091427966988849331 --- /dev/null +++ b/src/services/embeddings/service.py @@ -0,0 +1,147 @@ +""" +MediGuard AI — Embedding Service + +Supports Jina AI, Google, HuggingFace, and Ollama embeddings with +automatic fallback chain: Jina → Google → HuggingFace. 
+""" + +from __future__ import annotations + +import logging +from functools import lru_cache +from typing import List + +from src.exceptions import EmbeddingError, EmbeddingProviderError +from src.settings import get_settings + +logger = logging.getLogger(__name__) + + +class EmbeddingService: + """Unified embedding interface — delegates to the configured provider.""" + + def __init__(self, model, provider_name: str, dimension: int): + self._model = model + self.provider_name = provider_name + self.dimension = dimension + + def embed_query(self, text: str) -> List[float]: + """Embed a single query text.""" + try: + return self._model.embed_query(text) + except Exception as exc: + raise EmbeddingProviderError(f"{self.provider_name} embed_query failed: {exc}") + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Batch-embed a list of texts.""" + try: + return self._model.embed_documents(texts) + except Exception as exc: + raise EmbeddingProviderError(f"{self.provider_name} embed_documents failed: {exc}") + + +def _make_google_embeddings(): + settings = get_settings() + api_key = settings.embedding.google_api_key or settings.llm.google_api_key + if not api_key: + raise EmbeddingError("GOOGLE_API_KEY not set for Google embeddings") + from langchain_google_genai import GoogleGenerativeAIEmbeddings + + model = GoogleGenerativeAIEmbeddings( + model="models/text-embedding-004", + google_api_key=api_key, + ) + return EmbeddingService(model, "google", 768) + + +def _make_huggingface_embeddings(): + settings = get_settings() + try: + from langchain_huggingface import HuggingFaceEmbeddings + except ImportError: + from langchain_community.embeddings import HuggingFaceEmbeddings + + model = HuggingFaceEmbeddings(model_name=settings.embedding.huggingface_model) + return EmbeddingService(model, "huggingface", 384) + + +def _make_ollama_embeddings(): + settings = get_settings() + try: + from langchain_ollama import OllamaEmbeddings + except ImportError: + 
def _make_jina_embeddings():
    """Build an EmbeddingService backed by the Jina AI embeddings HTTP API.

    Uses a minimal in-line httpx adapter instead of the Jina SDK; the
    embedding dimension comes from settings rather than being hard-coded.

    Raises:
        EmbeddingError: when JINA_API_KEY is not configured.
    """
    settings = get_settings()
    api_key = settings.embedding.jina_api_key
    if not api_key:
        raise EmbeddingError("JINA_API_KEY not set for Jina embeddings")
    # Jina v3 via httpx (lightweight, no extra SDK)
    import httpx

    class _JinaModel:
        """Minimal Jina AI embedding adapter."""

        def __init__(self, api_key: str, model: str):
            self._api_key = api_key
            self._model = model
            self._url = "https://api.jina.ai/v1/embeddings"

        def _call(self, texts: list[str], task: str = "retrieval.passage") -> list[list[float]]:
            # POST to /v1/embeddings; results are re-sorted by their "index"
            # field so output order matches input order.
            headers = {"Authorization": f"Bearer {self._api_key}", "Content-Type": "application/json"}
            payload = {"model": self._model, "input": texts, "task": task}
            resp = httpx.post(self._url, json=payload, headers=headers, timeout=60)
            resp.raise_for_status()
            data = resp.json()["data"]
            return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]

        def embed_query(self, text: str) -> list[float]:
            # Asymmetric retrieval: queries use the "retrieval.query" task.
            return self._call([text], task="retrieval.query")[0]

        def embed_documents(self, texts: list[str]) -> list[list[float]]:
            return self._call(texts, task="retrieval.passage")

    model = _JinaModel(api_key, settings.embedding.jina_model)
    return EmbeddingService(model, "jina", settings.embedding.dimension)
@lru_cache(maxsize=1)
def make_embedding_service() -> EmbeddingService:
    """Create an embedding service with automatic fallback.

    Tries the configured provider first, then the remaining providers in
    FALLBACK_ORDER; the first factory that constructs successfully wins.
    lru_cache keeps a single provider instance per process.

    Raises:
        EmbeddingError: when every provider fails to initialise.
    """
    settings = get_settings()
    preferred = settings.embedding.provider

    # Try preferred first, then fallbacks
    order = [preferred] + [p for p in FALLBACK_ORDER if p != preferred]
    for provider in order:
        factory = _PROVIDERS.get(provider)
        if factory is None:
            # Unknown provider name in config — skip rather than crash.
            continue
        try:
            svc = factory()
            logger.info("Embedding provider: %s (dim=%d)", svc.provider_name, svc.dimension)
            return svc
        except Exception as exc:
            logger.warning("Embedding provider '%s' failed: %s — trying next", provider, exc)

    raise EmbeddingError("All embedding providers failed. Check your API keys and configuration.")
+""" + +from __future__ import annotations + +import logging +import uuid +from datetime import datetime, timezone +from typing import Dict, List + +from src.services.indexing.text_chunker import MedicalChunk, MedicalTextChunker + +logger = logging.getLogger(__name__) + + +class IndexingService: + """Coordinates chunking → embedding → OpenSearch indexing.""" + + def __init__(self, chunker, embedding_service, opensearch_client): + self.chunker = chunker + self.embedding_service = embedding_service + self.opensearch_client = opensearch_client + + def index_text( + self, + text: str, + *, + document_id: str = "", + title: str = "", + source_file: str = "", + ) -> int: + """Chunk, embed, and index a single document's text. Returns count of indexed chunks.""" + if not document_id: + document_id = str(uuid.uuid4()) + + chunks = self.chunker.chunk_text( + text, + document_id=document_id, + title=title, + source_file=source_file, + ) + if not chunks: + logger.warning("No chunks generated for document '%s'", title) + return 0 + + # Embed all chunks + texts = [c.text for c in chunks] + embeddings = self.embedding_service.embed_documents(texts) + + # Prepare OpenSearch documents + now = datetime.now(timezone.utc).isoformat() + docs: List[Dict] = [] + for chunk, emb in zip(chunks, embeddings): + doc = chunk.to_dict() + doc["_id"] = f"{document_id}_{chunk.chunk_index}" + doc["embedding"] = emb + doc["indexed_at"] = now + docs.append(doc) + + indexed = self.opensearch_client.bulk_index(docs) + logger.info( + "Indexed %d chunks for '%s' (document_id=%s)", + indexed, title, document_id, + ) + return indexed + + def index_chunks(self, chunks: List[MedicalChunk]) -> int: + """Embed and index pre-built chunks.""" + if not chunks: + return 0 + texts = [c.text for c in chunks] + embeddings = self.embedding_service.embed_documents(texts) + now = datetime.now(timezone.utc).isoformat() + docs: List[Dict] = [] + for chunk, emb in zip(chunks, embeddings): + doc = chunk.to_dict() + 
doc["_id"] = f"{chunk.document_id}_{chunk.chunk_index}" + doc["embedding"] = emb + doc["indexed_at"] = now + docs.append(doc) + return self.opensearch_client.bulk_index(docs) diff --git a/src/services/indexing/text_chunker.py b/src/services/indexing/text_chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..9710214b16747e84ad4738551e72d35f11340490 --- /dev/null +++ b/src/services/indexing/text_chunker.py @@ -0,0 +1,178 @@ +""" +MediGuard AI — Medical-Aware Text Chunker + +Section-aware chunking with biomarker / condition metadata extraction. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set + +# Biomarker names to detect in chunk text +_BIOMARKER_NAMES: Set[str] = { + "Glucose", "Cholesterol", "Triglycerides", "HbA1c", "LDL", "HDL", + "Insulin", "BMI", "Hemoglobin", "Platelets", "WBC", "RBC", + "Hematocrit", "MCV", "MCH", "MCHC", "Heart Rate", "Systolic", + "Diastolic", "Troponin", "CRP", "C-reactive Protein", "ALT", "AST", + "Creatinine", "TSH", "T3", "T4", "Sodium", "Potassium", "Calcium", +} + +_CONDITION_KEYWORDS: Dict[str, str] = { + "diabetes": "diabetes", + "diabetic": "diabetes", + "hyperglycemia": "diabetes", + "insulin resistance": "diabetes", + "anemia": "anemia", + "anaemia": "anemia", + "iron deficiency": "anemia", + "thalassemia": "thalassemia", + "thalassaemia": "thalassemia", + "thrombocytopenia": "thrombocytopenia", + "heart disease": "heart_disease", + "cardiovascular": "heart_disease", + "coronary": "heart_disease", + "hypertension": "heart_disease", + "atherosclerosis": "heart_disease", + "hyperlipidemia": "heart_disease", +} + +_SECTION_RE = re.compile( + r"^(?:#+\s*)?(" + r"abstract|introduction|background|methods?|methodology|materials?" 
+ r"|results?|findings|discussion|conclusion|summary" + r"|guidelines?|recommendations?|references?|bibliography" + r"|clinical\s*presentation|pathophysiology|diagnosis|treatment|prognosis" + r")\b", + re.IGNORECASE | re.MULTILINE, +) + + +@dataclass +class MedicalChunk: + """A single chunk with medical metadata.""" + text: str + chunk_index: int + document_id: str = "" + title: str = "" + source_file: str = "" + page_number: Optional[int] = None + section_title: str = "" + biomarkers_mentioned: List[str] = field(default_factory=list) + condition_tags: List[str] = field(default_factory=list) + word_count: int = 0 + + def to_dict(self) -> Dict: + return { + "chunk_text": self.text, + "chunk_index": self.chunk_index, + "document_id": self.document_id, + "title": self.title, + "source_file": self.source_file, + "page_number": self.page_number, + "section_title": self.section_title, + "biomarkers_mentioned": self.biomarkers_mentioned, + "condition_tags": self.condition_tags, + } + + +class MedicalTextChunker: + """Section-aware text chunker optimised for medical documents.""" + + def __init__( + self, + target_words: int = 600, + overlap_words: int = 100, + min_words: int = 50, + ): + self.target_words = target_words + self.overlap_words = overlap_words + self.min_words = min_words + + def chunk_text( + self, + text: str, + *, + document_id: str = "", + title: str = "", + source_file: str = "", + ) -> List[MedicalChunk]: + """Split text into enriched medical chunks.""" + sections = self._split_sections(text) + chunks: List[MedicalChunk] = [] + idx = 0 + for section_title, section_text in sections: + words = section_text.split() + if not words: + continue + start = 0 + while start < len(words): + end = min(start + self.target_words, len(words)) + chunk_words = words[start:end] + if len(chunk_words) < self.min_words and chunks: + # merge tiny tail into previous chunk + chunks[-1].text += " " + " ".join(chunk_words) + chunks[-1].word_count = len(chunks[-1].text.split()) + 
break + + chunk_text = " ".join(chunk_words) + biomarkers = self._detect_biomarkers(chunk_text) + conditions = self._detect_conditions(chunk_text) + + chunks.append( + MedicalChunk( + text=chunk_text, + chunk_index=idx, + document_id=document_id, + title=title, + source_file=source_file, + section_title=section_title, + biomarkers_mentioned=biomarkers, + condition_tags=conditions, + word_count=len(chunk_words), + ) + ) + idx += 1 + start = end - self.overlap_words if end < len(words) else len(words) + return chunks + + # ── internal helpers ───────────────────────────────────────────────── + + @staticmethod + def _split_sections(text: str) -> List[tuple[str, str]]: + """Split text by detected section headers.""" + matches = list(_SECTION_RE.finditer(text)) + if not matches: + return [("", text)] + sections: List[tuple[str, str]] = [] + # text before first section header + if matches[0].start() > 0: + preamble = text[: matches[0].start()].strip() + if preamble: + sections.append(("", preamble)) + for i, match in enumerate(matches): + header = match.group(1).strip().title() + start = match.end() + end = matches[i + 1].start() if i + 1 < len(matches) else len(text) + body = text[start:end].strip() + # Skip reference/bibliography sections + if header.lower() in ("references", "bibliography"): + continue + if body: + sections.append((header, body)) + return sections or [("", text)] + + @staticmethod + def _detect_biomarkers(text: str) -> List[str]: + text_lower = text.lower() + return sorted( + {name for name in _BIOMARKER_NAMES if name.lower() in text_lower} + ) + + @staticmethod + def _detect_conditions(text: str) -> List[str]: + text_lower = text.lower() + return sorted( + {tag for kw, tag in _CONDITION_KEYWORDS.items() if kw in text_lower} + ) diff --git a/src/services/langfuse/__init__.py b/src/services/langfuse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abe206901714756d101b34827ac2e667f1fd15b4 --- /dev/null +++ 
"""MediGuard AI — Langfuse observability package (src/services/langfuse/__init__.py)."""

from __future__ import annotations

from src.services.langfuse.tracer import LangfuseTracer, make_langfuse_tracer

__all__ = ["LangfuseTracer", "make_langfuse_tracer"]


# ── file: src/services/langfuse/tracer.py ───────────────────────────────────
# MediGuard AI — Langfuse Observability Tracer
#
# Wraps the Langfuse SDK for end-to-end tracing of the RAG pipeline.
# Silently no-ops when Langfuse is disabled or unreachable.

import logging
from contextlib import contextmanager
from functools import lru_cache
from typing import Any

logger = logging.getLogger(__name__)

# Guard import: the tracer degrades to a no-op when the SDK is absent.
try:
    from langfuse import Langfuse as _Langfuse
except ImportError:
    _Langfuse = None  # type: ignore[assignment,misc]


class LangfuseTracer:
    """Thin wrapper around Langfuse for MediGuard pipeline tracing.

    Every method degrades to a harmless no-op when constructed with
    ``client=None`` (disabled mode).
    """

    def __init__(self, client: Any | None):
        self._client = client
        self._enabled = client is not None

    @property
    def enabled(self) -> bool:
        """True when a live Langfuse client is attached."""
        return self._enabled

    def trace(self, name: str, **kwargs: Any):
        """Create a new trace (top-level span); a no-op span when disabled.

        NOTE(review): ``client.trace(...)`` is the Langfuse v2-style API;
        the original module docstring claims v3 (which starts traces via
        ``start_span``). Confirm against the pinned SDK version.
        """
        if not self._enabled:
            return _NullSpan()
        return self._client.trace(name=name, **kwargs)

    @contextmanager
    def span(self, trace, name: str, **kwargs):
        """Context manager for a child span; the span is always ended on exit."""
        if not self._enabled or trace is None:
            yield _NullSpan()
            return
        s = trace.span(name=name, **kwargs)
        try:
            yield s
        finally:
            s.end()

    def score(self, trace_id: str, name: str, value: float, comment: str = ""):
        """Attach an evaluation score to a trace; failures are logged, never raised."""
        if not self._enabled:
            return
        try:
            self._client.score(trace_id=trace_id, name=name, value=value, comment=comment)
        except Exception as exc:
            logger.warning("Langfuse score failed: %s", exc)

    def flush(self):
        """Flush buffered events; best-effort, never raises."""
        if self._enabled:
            try:
                self._client.flush()
            except Exception:
                pass


class _NullSpan:
    """Dummy span object that silently swallows calls.

    Any attribute access returns a callable that yields another ``_NullSpan``,
    so chained usage (``span.event(...).end()``) keeps working when disabled.
    """

    def __getattr__(self, name: str):
        return lambda *a, **kw: _NullSpan()

    def end(self):
        pass


@lru_cache(maxsize=1)
def make_langfuse_tracer() -> LangfuseTracer:
    """Cached factory: a live tracer, or a disabled one on any failure."""
    # Imported here (not at module top) so this module stays importable on
    # its own and cannot participate in an import cycle with src.settings.
    from src.settings import get_settings

    settings = get_settings()
    if not settings.langfuse.enabled or _Langfuse is None:
        logger.info("Langfuse tracing disabled")
        return LangfuseTracer(None)
    try:
        client = _Langfuse(
            public_key=settings.langfuse.public_key,
            secret_key=settings.langfuse.secret_key,
            host=settings.langfuse.host,
        )
        logger.info("Langfuse connected (%s)", settings.langfuse.host)
        return LangfuseTracer(client)
    except Exception as exc:
        logger.warning("Langfuse unavailable: %s", exc)
        return LangfuseTracer(None)


# ── file: src/services/ollama/__init__.py ────────────────────────────────────
"""MediGuard AI — Ollama client package."""
from src.services.ollama.client import OllamaClient, make_ollama_client

__all__ = ["OllamaClient", "make_ollama_client"]
+""" + +from __future__ import annotations + +import logging +from functools import lru_cache +from typing import Any, Dict, Iterator, List, Optional + +import httpx + +from src.exceptions import OllamaConnectionError, OllamaModelNotFoundError +from src.settings import get_settings + +logger = logging.getLogger(__name__) + + +class OllamaClient: + """Wrapper around the Ollama REST API.""" + + def __init__(self, base_url: str, *, timeout: int = 120): + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self._http = httpx.Client(base_url=self.base_url, timeout=timeout) + + # ── Health ─────────────────────────────────────────────────────────── + + def ping(self) -> bool: + try: + resp = self._http.get("/api/version") + return resp.status_code == 200 + except Exception: + return False + + def health(self) -> Dict[str, Any]: + try: + resp = self._http.get("/api/version") + resp.raise_for_status() + return resp.json() + except Exception as exc: + raise OllamaConnectionError(f"Cannot reach Ollama: {exc}") + + def list_models(self) -> List[str]: + try: + resp = self._http.get("/api/tags") + resp.raise_for_status() + return [m["name"] for m in resp.json().get("models", [])] + except Exception as exc: + logger.warning("Failed to list Ollama models: %s", exc) + return [] + + # ── Generation ─────────────────────────────────────────────────────── + + def generate( + self, + prompt: str, + *, + model: Optional[str] = None, + system: str = "", + temperature: float = 0.0, + num_ctx: int = 8192, + ) -> Dict[str, Any]: + """Synchronous generation — returns the full response dict.""" + model = model or get_settings().ollama.model + payload: Dict[str, Any] = { + "model": model, + "prompt": prompt, + "stream": False, + "options": {"temperature": temperature, "num_ctx": num_ctx}, + } + if system: + payload["system"] = system + try: + resp = self._http.post("/api/generate", json=payload) + resp.raise_for_status() + return resp.json() + except httpx.HTTPStatusError as exc: 
+ if exc.response.status_code == 404: + raise OllamaModelNotFoundError(f"Model '{model}' not found on Ollama server") + raise OllamaConnectionError(str(exc)) + except Exception as exc: + raise OllamaConnectionError(str(exc)) + + def generate_stream( + self, + prompt: str, + *, + model: Optional[str] = None, + system: str = "", + temperature: float = 0.0, + num_ctx: int = 8192, + ) -> Iterator[str]: + """Streaming generation — yields text tokens.""" + model = model or get_settings().ollama.model + payload: Dict[str, Any] = { + "model": model, + "prompt": prompt, + "stream": True, + "options": {"temperature": temperature, "num_ctx": num_ctx}, + } + if system: + payload["system"] = system + try: + with self._http.stream("POST", "/api/generate", json=payload) as resp: + resp.raise_for_status() + import json + for line in resp.iter_lines(): + if line: + data = json.loads(line) + token = data.get("response", "") + if token: + yield token + if data.get("done", False): + break + except Exception as exc: + raise OllamaConnectionError(str(exc)) + + # ── LangChain integration ──────────────────────────────────────────── + + def get_langchain_model( + self, + *, + model: Optional[str] = None, + temperature: float = 0.0, + json_mode: bool = False, + ): + """Return a LangChain ChatOllama instance.""" + model = model or get_settings().ollama.model + try: + from langchain_ollama import ChatOllama + except ImportError: + from langchain_community.chat_models import ChatOllama + + return ChatOllama( + model=model, + temperature=temperature, + base_url=self.base_url, + format="json" if json_mode else None, + ) + + def close(self): + self._http.close() + + +@lru_cache(maxsize=1) +def make_ollama_client() -> OllamaClient: + settings = get_settings() + client = OllamaClient( + base_url=settings.ollama.host, + timeout=settings.ollama.timeout, + ) + if client.ping(): + logger.info("Ollama connected at %s", settings.ollama.host) + else: + logger.warning("Ollama not reachable at %s", 
settings.ollama.host) + return client diff --git a/src/services/opensearch/__init__.py b/src/services/opensearch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c59b705cca78d8e1c05c61821226cb33d8fe0427 --- /dev/null +++ b/src/services/opensearch/__init__.py @@ -0,0 +1,5 @@ +"""MediGuard AI — OpenSearch service package.""" +from src.services.opensearch.client import OpenSearchClient, make_opensearch_client +from src.services.opensearch.index_config import MEDICAL_CHUNKS_MAPPING + +__all__ = ["OpenSearchClient", "make_opensearch_client", "MEDICAL_CHUNKS_MAPPING"] diff --git a/src/services/opensearch/client.py b/src/services/opensearch/client.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8b02bb41f2bab012dbafec8b30f603698df5df --- /dev/null +++ b/src/services/opensearch/client.py @@ -0,0 +1,224 @@ +""" +MediGuard AI — OpenSearch Client + +Production search-engine wrapper supporting BM25, vector (KNN), and +hybrid search with Reciprocal Rank Fusion (RRF). 
+""" + +from __future__ import annotations + +import logging +from functools import lru_cache +from typing import Any, Dict, List, Optional + +from src.exceptions import IndexNotFoundError, SearchError, SearchQueryError +from src.settings import get_settings + +logger = logging.getLogger(__name__) + +# Guard import — opensearch-py is optional when running tests locally +try: + from opensearchpy import OpenSearch, RequestError, NotFoundError as OSNotFoundError +except ImportError: # pragma: no cover + OpenSearch = None # type: ignore[assignment,misc] + + +class OpenSearchClient: + """Thin wrapper around *opensearch-py* with medical-domain helpers.""" + + def __init__(self, client: "OpenSearch", index_name: str): + self._client = client + self.index_name = index_name + + # ── Health ─────────────────────────────────────────────────────────── + + def health(self) -> Dict[str, Any]: + return self._client.cluster.health() + + def ping(self) -> bool: + try: + return self._client.ping() + except Exception: + return False + + # ── Index management ───────────────────────────────────────────────── + + def ensure_index(self, mapping: Dict[str, Any]) -> None: + """Create the index if it doesn't already exist.""" + if not self._client.indices.exists(index=self.index_name): + self._client.indices.create(index=self.index_name, body=mapping) + logger.info("Created OpenSearch index '%s'", self.index_name) + else: + logger.info("OpenSearch index '%s' already exists", self.index_name) + + def delete_index(self) -> None: + if self._client.indices.exists(index=self.index_name): + self._client.indices.delete(index=self.index_name) + + def doc_count(self) -> int: + try: + resp = self._client.count(index=self.index_name) + return resp["count"] + except Exception: + return 0 + + # ── Indexing ───────────────────────────────────────────────────────── + + def index_document(self, doc_id: str, body: Dict[str, Any]) -> None: + self._client.index(index=self.index_name, id=doc_id, body=body) + 
+ def bulk_index(self, documents: List[Dict[str, Any]]) -> int: + """Bulk-index a list of dicts, each must have an ``_id`` key.""" + if not documents: + return 0 + actions: list[Dict[str, Any]] = [] + for doc in documents: + doc_id = doc.pop("_id", None) + actions.append({"index": {"_index": self.index_name, "_id": doc_id}}) + actions.append(doc) + resp = self._client.bulk(body=actions, refresh="wait_for") + indexed = sum(1 for item in resp.get("items", []) if item.get("index", {}).get("status") in (200, 201)) + logger.info("Bulk-indexed %d / %d documents", indexed, len(documents)) + return indexed + + # ── BM25 search ────────────────────────────────────────────────────── + + def search_bm25( + self, + query: str, + *, + top_k: int = 10, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Dict[str, Any]]: + body: Dict[str, Any] = { + "size": top_k, + "query": { + "bool": { + "must": [ + { + "multi_match": { + "query": query, + "fields": [ + "chunk_text^3", + "title^2", + "section_title^1.5", + "abstract^1", + ], + "type": "best_fields", + } + } + ] + } + }, + } + if filters: + body["query"]["bool"]["filter"] = self._build_filters(filters) + return self._execute_search(body) + + # ── Vector (KNN) search ────────────────────────────────────────────── + + def search_vector( + self, + embedding: List[float], + *, + top_k: int = 10, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Dict[str, Any]]: + body: Dict[str, Any] = { + "size": top_k, + "query": { + "knn": { + "embedding": { + "vector": embedding, + "k": top_k, + } + } + }, + } + return self._execute_search(body) + + # ── Hybrid search (RRF) ───────────────────────────────────────────── + + def search_hybrid( + self, + query: str, + embedding: List[float], + *, + top_k: int = 10, + filters: Optional[Dict[str, Any]] = None, + bm25_weight: float = 0.4, + vector_weight: float = 0.6, + ) -> List[Dict[str, Any]]: + """Reciprocal Rank Fusion of BM25 + KNN results.""" + bm25_results = self.search_bm25(query, 
top_k=top_k, filters=filters) + vector_results = self.search_vector(embedding, top_k=top_k, filters=filters) + return self._rrf_fuse(bm25_results, vector_results, top_k=top_k) + + # ── Internal helpers ───────────────────────────────────────────────── + + def _execute_search(self, body: Dict[str, Any]) -> List[Dict[str, Any]]: + try: + resp = self._client.search(index=self.index_name, body=body) + except Exception as exc: + raise SearchQueryError(str(exc)) + hits = resp.get("hits", {}).get("hits", []) + return [ + { + "_id": h["_id"], + "_score": h.get("_score", 0.0), + **h.get("_source", {}), + } + for h in hits + ] + + @staticmethod + def _build_filters(filters: Dict[str, Any]) -> List[Dict[str, Any]]: + clauses: List[Dict[str, Any]] = [] + for key, value in filters.items(): + if isinstance(value, list): + clauses.append({"terms": {key: value}}) + else: + clauses.append({"term": {key: value}}) + return clauses + + @staticmethod + def _rrf_fuse( + results_a: List[Dict[str, Any]], + results_b: List[Dict[str, Any]], + *, + k: int = 60, + top_k: int = 10, + ) -> List[Dict[str, Any]]: + """Simple Reciprocal Rank Fusion.""" + scores: Dict[str, float] = {} + docs: Dict[str, Dict[str, Any]] = {} + for rank, doc in enumerate(results_a, 1): + doc_id = doc["_id"] + scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank) + docs[doc_id] = doc + for rank, doc in enumerate(results_b, 1): + doc_id = doc["_id"] + scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank) + docs[doc_id] = doc + ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k] + return [ + {**docs[doc_id], "_score": score} + for doc_id, score in ranked + ] + + +# ── Factory ────────────────────────────────────────────────────────────────── + +@lru_cache(maxsize=1) +def make_opensearch_client() -> OpenSearchClient: + if OpenSearch is None: + raise SearchError("opensearch-py is not installed") + settings = get_settings() + os_settings = settings.opensearch + client = OpenSearch( + 
hosts=[os_settings.host], + http_auth=(os_settings.username, os_settings.password) if os_settings.username else None, + verify_certs=os_settings.verify_certs, + timeout=os_settings.timeout, + ) + return OpenSearchClient(client, os_settings.index_name) diff --git a/src/services/opensearch/index_config.py b/src/services/opensearch/index_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5bda139fb1dc1df579f9956048ac0dc980d1bf2f --- /dev/null +++ b/src/services/opensearch/index_config.py @@ -0,0 +1,88 @@ +""" +MediGuard AI — OpenSearch index mapping for medical document chunks. + +Includes a medical synonym analyzer and KNN vector field for hybrid search. +""" + +MEDICAL_CHUNKS_MAPPING: dict = { + "settings": { + "index": { + "knn": True, + "knn.algo_param.ef_search": 256, + }, + "number_of_shards": 1, + "number_of_replicas": 0, + "analysis": { + "filter": { + "medical_synonyms": { + "type": "synonym", + "synonyms": [ + "diabetes mellitus, DM, diabetes", + "HbA1c, glycated hemoglobin, hemoglobin A1c, A1c", + "glucose, blood sugar, blood glucose", + "LDL, low density lipoprotein, bad cholesterol", + "HDL, high density lipoprotein, good cholesterol", + "WBC, white blood cells, leukocytes", + "RBC, red blood cells, erythrocytes", + "MCV, mean corpuscular volume", + "BP, blood pressure", + "CRP, C-reactive protein", + "ALT, alanine aminotransferase, SGPT", + "AST, aspartate aminotransferase, SGOT", + "TSH, thyroid stimulating hormone", + "BMI, body mass index", + "anemia, anaemia", + "thrombocytopenia, low platelets", + "thalassemia, thalassaemia", + ], + } + }, + "analyzer": { + "medical_analyzer": { + "type": "custom", + "tokenizer": "standard", + "filter": [ + "lowercase", + "medical_synonyms", + "stop", + "snowball", + ], + } + }, + }, + }, + "mappings": { + "properties": { + # ── Text fields ──────────────────────────────────────── + "chunk_text": { + "type": "text", + "analyzer": "medical_analyzer", + }, + "title": {"type": "text", 
"analyzer": "medical_analyzer"}, + "section_title": {"type": "text"}, + "abstract": {"type": "text", "analyzer": "medical_analyzer"}, + # ── Keyword / filterable ─────────────────────────────── + "document_id": {"type": "keyword"}, + "document_type": {"type": "keyword"}, # guideline, research, reference + "condition_tags": {"type": "keyword"}, # diabetes, anemia, … + "biomarkers_mentioned": {"type": "keyword"}, # Glucose, HbA1c, … + "source_file": {"type": "keyword"}, + "page_number": {"type": "integer"}, + "chunk_index": {"type": "integer"}, + "publication_year": {"type": "integer"}, + # ── Vector (KNN) ─────────────────────────────────────── + "embedding": { + "type": "knn_vector", + "dimension": 1024, + "method": { + "name": "hnsw", + "space_type": "cosinesimil", + "engine": "nmslib", + "parameters": {"ef_construction": 256, "m": 48}, + }, + }, + # ── Timestamps ───────────────────────────────────────── + "indexed_at": {"type": "date"}, + } + }, +} diff --git a/src/services/pdf_parser/__init__.py b/src/services/pdf_parser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4214e5529685c7f600ccae795bf073b583ca9b1 --- /dev/null +++ b/src/services/pdf_parser/__init__.py @@ -0,0 +1 @@ +"""MediGuard AI — PDF parsing service.""" diff --git a/src/services/pdf_parser/service.py b/src/services/pdf_parser/service.py new file mode 100644 index 0000000000000000000000000000000000000000..5b0bf49b9b01d2c8a758ff9bb6361fdb4f4e64c8 --- /dev/null +++ b/src/services/pdf_parser/service.py @@ -0,0 +1,162 @@ +""" +MediGuard AI — PDF Parser Service + +Production PDF parsing with Docling (preferred) falling back to PyPDF. +Returns structured text with section metadata. 
+""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass, field +from functools import lru_cache +from pathlib import Path +from typing import List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class ParsedSection: + """One logical section extracted from a PDF.""" + + title: str + text: str + page_numbers: List[int] = field(default_factory=list) + + +@dataclass +class ParsedDocument: + """Result of parsing a single PDF.""" + + filename: str + content_hash: str + full_text: str + sections: List[ParsedSection] = field(default_factory=list) + page_count: int = 0 + error: Optional[str] = None + + +class PDFParserService: + """Unified PDF parsing with Docling → PyPDF fallback.""" + + def __init__(self) -> None: + self._has_docling = self._check_docling() + + @staticmethod + def _check_docling() -> bool: + try: + import docling # noqa: F401 + return True + except ImportError: + logger.info("Docling not installed — using PyPDF fallback") + return False + + def parse(self, path: Path) -> ParsedDocument: + """Parse a PDF file and return structured text.""" + if not path.exists(): + return ParsedDocument( + filename=path.name, + content_hash="", + full_text="", + error=f"File not found: {path}", + ) + + content_hash = hashlib.sha256(path.read_bytes()).hexdigest() + + if self._has_docling: + return self._parse_with_docling(path, content_hash) + return self._parse_with_pypdf(path, content_hash) + + # ------------------------------------------------------------------ # + # Docling (preferred) + # ------------------------------------------------------------------ # + + def _parse_with_docling(self, path: Path, content_hash: str) -> ParsedDocument: + try: + from docling.document_converter import DocumentConverter + + converter = DocumentConverter() + result = converter.convert(str(path)) + doc = result.document + + sections: list[ParsedSection] = [] + full_parts: list[str] = [] + + for element in 
doc.iterate_items(): + text = element.text if hasattr(element, "text") else str(element) + if text.strip(): + full_parts.append(text.strip()) + sections.append( + ParsedSection( + title=getattr(element, "label", ""), + text=text.strip(), + ) + ) + + full_text = "\n\n".join(full_parts) + return ParsedDocument( + filename=path.name, + content_hash=content_hash, + full_text=full_text, + sections=sections, + page_count=getattr(doc, "num_pages", 0), + ) + except Exception as exc: + logger.warning("Docling failed for %s — falling back to PyPDF: %s", path.name, exc) + return self._parse_with_pypdf(path, content_hash) + + # ------------------------------------------------------------------ # + # PyPDF fallback + # ------------------------------------------------------------------ # + + def _parse_with_pypdf(self, path: Path, content_hash: str) -> ParsedDocument: + try: + from pypdf import PdfReader + + reader = PdfReader(str(path)) + pages_text: list[str] = [] + for i, page in enumerate(reader.pages): + text = page.extract_text() or "" + if text.strip(): + pages_text.append(text.strip()) + + full_text = "\n\n".join(pages_text) + sections = [ + ParsedSection(title=f"Page {i + 1}", text=t, page_numbers=[i + 1]) + for i, t in enumerate(pages_text) + ] + + return ParsedDocument( + filename=path.name, + content_hash=content_hash, + full_text=full_text, + sections=sections, + page_count=len(reader.pages), + ) + except Exception as exc: + logger.error("PyPDF failed for %s: %s", path.name, exc) + return ParsedDocument( + filename=path.name, + content_hash=content_hash, + full_text="", + error=str(exc), + ) + + # ------------------------------------------------------------------ # + # Batch + # ------------------------------------------------------------------ # + + def parse_directory(self, directory: Path) -> List[ParsedDocument]: + """Parse all PDFs in a directory.""" + results: list[ParsedDocument] = [] + for pdf_path in sorted(directory.glob("*.pdf")): + logger.info("Parsing 
%s …", pdf_path.name) + results.append(self.parse(pdf_path)) + return results + + +@lru_cache(maxsize=1) +def make_pdf_parser_service() -> PDFParserService: + return PDFParserService() diff --git a/src/services/telegram/__init__.py b/src/services/telegram/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..250bebc5e8ebb9373c73ee4e32a1dafef74788fd --- /dev/null +++ b/src/services/telegram/__init__.py @@ -0,0 +1 @@ +"""MediGuard AI — Telegram bot service.""" diff --git a/src/services/telegram/bot.py b/src/services/telegram/bot.py new file mode 100644 index 0000000000000000000000000000000000000000..01a69a5b00d0f0552fc969993016b9e7f1315adc --- /dev/null +++ b/src/services/telegram/bot.py @@ -0,0 +1,102 @@ +""" +MediGuard AI — Telegram Bot + +Lightweight Telegram bot that proxies user messages to the /ask endpoint. +Requires ``python-telegram-bot`` (installed via extras ``[telegram]``). +""" + +from __future__ import annotations + +import logging +import os +from typing import Optional + +logger = logging.getLogger(__name__) + +# Lazy import — only needed when the bot is actually started +_Application = None + + +def _get_telegram(): + global _Application + try: + from telegram import Update + from telegram.ext import Application, CommandHandler, MessageHandler, filters + _Application = Application + return Update, Application, CommandHandler, MessageHandler, filters + except ImportError: + raise ImportError( + "python-telegram-bot is required for the Telegram bot. 
" + "Install it with: pip install 'mediguard[telegram]' or pip install python-telegram-bot" + ) + + +class MediGuardTelegramBot: + """Telegram bot that wraps a ``requests`` call to the API ``/ask`` endpoint.""" + + def __init__( + self, + token: Optional[str] = None, + api_base_url: str = "http://localhost:8000", + ) -> None: + self._token = token or os.getenv("TELEGRAM_BOT_TOKEN", "") + self._api_base = api_base_url.rstrip("/") + + if not self._token: + raise ValueError("TELEGRAM_BOT_TOKEN is required") + + def run(self) -> None: + """Start the bot (blocking).""" + import httpx + + Update, Application, CommandHandler, MessageHandler, filters = _get_telegram() + + app = Application.builder().token(self._token).build() + + async def start_handler(update: Update, context) -> None: + await update.message.reply_text( + "Welcome to MediGuard AI! Send me a medical question or biomarker values " + "and I'll provide evidence-based insights.\n\n" + "Disclaimer: This is not a substitute for professional medical advice." + ) + + async def help_handler(update: Update, context) -> None: + await update.message.reply_text( + "Send me:\n" + "• A medical question (e.g. 'What does high HbA1c mean?')\n" + "• Biomarker values (e.g. 'My glucose is 180 and HbA1c 8.2')\n\n" + "I'll provide evidence-based analysis." + ) + + async def message_handler(update: Update, context) -> None: + user_text = update.message.text or "" + if not user_text.strip(): + return + + await update.message.reply_text("Analyzing… please wait.") + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post( + f"{self._api_base}/ask", + json={"question": user_text}, + ) + resp.raise_for_status() + data = resp.json() + answer = data.get("answer", "Sorry, I could not generate an answer.") + except Exception as exc: + logger.error("Telegram→API call failed: %s", exc) + answer = "Sorry, I'm having trouble processing your request right now." 
"""
MediGuard AI — Pydantic Settings (hierarchical, env-driven)

All runtime configuration lives here. Values are read from environment
variables (with ``env_nested_delimiter="__"``), so ``OPENSEARCH__HOST``
maps to ``settings.opensearch.host``.

Usage::

    from src.settings import get_settings
    settings = get_settings()
    print(settings.opensearch.host)
"""

from __future__ import annotations

from functools import lru_cache
from typing import List, Literal, Optional

from pydantic import Field
from pydantic_settings import BaseSettings


# ── Helpers ──────────────────────────────────────────────────────────────────

class _Base(BaseSettings):
    """Shared Settings base with nested-env support."""

    # frozen=True makes every settings object immutable (safe to share/cache);
    # extra="ignore" tolerates unrelated variables in the environment.
    model_config = {
        "env_nested_delimiter": "__",
        "frozen": True,
        "extra": "ignore",
    }


# ── Sub-settings ─────────────────────────────────────────────────────────────

class APISettings(_Base):
    # FastAPI server binding and CORS. Read from API__* env vars.
    host: str = "0.0.0.0"
    port: int = 8000
    reload: bool = False
    workers: int = 4
    cors_origins: str = "*"
    log_level: str = "INFO"

    model_config = {"env_prefix": "API__"}


class PostgresSettings(_Base):
    # Full SQLAlchemy URL; read from POSTGRES__DATABASE_URL.
    # NOTE(review): .env.example sets POSTGRES__HOST/PORT/DATABASE/USER/PASSWORD,
    # which have no matching fields here and are silently ignored — confirm.
    database_url: str = "postgresql+psycopg2://mediguard:mediguard@localhost:5432/mediguard_db"

    model_config = {"env_prefix": "POSTGRES__"}


class OpenSearchSettings(_Base):
    # Connection + index configuration for the search cluster.
    host: str = "http://localhost:9200"
    index_name: str = "medical_chunks"
    username: str = ""  # empty disables HTTP basic auth
    password: str = ""
    verify_certs: bool = False
    timeout: int = 30

    model_config = {"env_prefix": "OPENSEARCH__"}


class RedisSettings(_Base):
    # Cache backend; enabled=False disables caching entirely.
    host: str = "localhost"
    port: int = 6379
    db: int = 0
    ttl_seconds: int = 21600  # 6 hours default
    enabled: bool = True

    model_config = {"env_prefix": "REDIS__"}


class OllamaSettings(_Base):
    # Local LLM server.
    # NOTE(review): .env.example sets OLLAMA__BASE_URL, but this field reads
    # OLLAMA__HOST — the env var would be ignored; confirm the intended name.
    host: str = "http://localhost:11434"
    model: str = "llama3.1:8b"
    embedding_model: str = "nomic-embed-text"
    timeout: int = 120
    num_ctx: int = 8192

    model_config = {"env_prefix": "OLLAMA__"}


class LLMSettings(_Base):
    # Hosted LLM providers (Groq / Gemini) and provider selection.
    # NOTE(review): .env.example uses LLM__PRIMARY_PROVIDER and
    # LLM__GEMINI_API_KEY; fields here are `provider` and `google_api_key`
    # — verify the env-var ↔ field mapping.
    provider: Literal["groq", "gemini", "ollama"] = "groq"
    temperature: float = 0.0
    groq_api_key: str = ""
    groq_model: str = "llama-3.3-70b-versatile"
    google_api_key: str = ""
    gemini_model: str = "gemini-2.0-flash"

    model_config = {"env_prefix": "LLM__"}


class EmbeddingSettings(_Base):
    # Embedding provider; `dimension` must match the OpenSearch knn_vector
    # mapping (1024 by default).
    provider: Literal["jina", "google", "huggingface", "ollama"] = "google"
    jina_api_key: str = ""
    jina_model: str = "jina-embeddings-v3"
    dimension: int = 1024
    google_api_key: str = ""
    huggingface_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    batch_size: int = 64

    model_config = {"env_prefix": "EMBEDDING__"}


class ChunkingSettings(_Base):
    # Word-based chunking parameters for MedicalTextChunker.
    chunk_size: int = 600  # words
    chunk_overlap: int = 100  # words
    min_chunk_size: int = 50
    section_aware: bool = True

    model_config = {"env_prefix": "CHUNKING__"}


class LangfuseSettings(_Base):
    # Observability; disabled by default so local runs need no Langfuse server.
    enabled: bool = False
    public_key: str = ""
    secret_key: str = ""
    host: str = "http://localhost:3001"

    model_config = {"env_prefix": "LANGFUSE__"}


class TelegramSettings(_Base):
    enabled: bool = False
    bot_token: str = ""
    allowed_users: str = ""  # comma-separated user IDs

    model_config = {"env_prefix": "TELEGRAM__"}


class BiomarkerSettings(_Base):
    # Reference ranges and alerting strictness for the biomarker analyzer.
    reference_file: str = "config/biomarker_references.json"
    analyzer_threshold: float = 0.15
    critical_alert_mode: Literal["strict", "moderate", "permissive"] = "strict"

    model_config = {"env_prefix": "BIOMARKER__"}


class MedicalPDFSettings(_Base):
    # Ingestion limits for the PDF pipeline.
    # NOTE(review): .env.example uses MEDICAL_PDFS__DIRECTORY; this class reads
    # PDF__PDF_DIRECTORY — confirm which prefix is intended.
    pdf_directory: str = "data/medical_pdfs"
    vector_store_path: str = "data/vector_stores"
    max_file_size_mb: int = 50
    max_pages: int = 500

    model_config = {"env_prefix": "PDF__"}


# ── Root settings ────────────────────────────────────────────────────────────

class Settings(_Base):
    """Root configuration — aggregates all sub-settings."""

    app_name: str = "MediGuard AI"
    app_version: str = "2.0.0"
    environment: Literal["development", "staging", "production"] = "development"
    debug: bool = False

    # Sub-settings (populated from env with nesting). Each sub-model also
    # reads its own env_prefix when constructed via default_factory.
    api: APISettings = Field(default_factory=APISettings)
    postgres: PostgresSettings = Field(default_factory=PostgresSettings)
    opensearch: OpenSearchSettings = Field(default_factory=OpenSearchSettings)
    redis: RedisSettings = Field(default_factory=RedisSettings)
    ollama: OllamaSettings = Field(default_factory=OllamaSettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    chunking: ChunkingSettings = Field(default_factory=ChunkingSettings)
    langfuse: LangfuseSettings = Field(default_factory=LangfuseSettings)
    telegram: TelegramSettings = Field(default_factory=TelegramSettings)
    biomarker: BiomarkerSettings = Field(default_factory=BiomarkerSettings)
    pdf: MedicalPDFSettings = Field(default_factory=MedicalPDFSettings)

    model_config = {
        "env_nested_delimiter": "__",
        "frozen": True,
        "extra": "ignore",
    }


@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Cached factory — returns a single frozen ``Settings`` instance."""
    return Settings()
"""
Tests for src/services/agents/ — agentic RAG pipeline.

Each node is exercised with a programmable MockLLM so no real model,
search index, or network connection is needed.
"""

from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock

import pytest


# -----------------------------------------------------------------------
# Mock context and LLM
# -----------------------------------------------------------------------


class MockMessage:
    """Minimal stand-in for an LLM chat message (only ``.content`` is read)."""

    def __init__(self, content: str):
        self.content = content


class MockLLM:
    """Programmable mock LLM that returns canned responses in order."""

    def __init__(self, responses: list[str] | None = None):
        self._responses = responses or []
        self._call_count = 0

    def invoke(self, messages: list) -> MockMessage:
        """Replay canned responses; fall back to a neutral score payload."""
        if self._call_count < len(self._responses):
            resp = self._responses[self._call_count]
        else:
            resp = '{"score": 80}'
        self._call_count += 1
        return MockMessage(resp)


@dataclass
class MockContext:
    """Bag of injectable services passed to every agent node."""

    llm: Any = None
    embedding_service: Any = None
    opensearch_client: Any = None
    cache: Any = None
    tracer: Any = None


# -----------------------------------------------------------------------
# Guardrail node
# -----------------------------------------------------------------------

class TestGuardrailNode:
    def test_in_scope_query(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        ctx = MockContext(llm=MockLLM(['{"score": 85}']))
        state = {"query": "What does high HbA1c mean?"}
        result = guardrail_node(state, context=ctx)
        assert result["is_in_scope"] is True
        assert result["guardrail_score"] == 85.0

    def test_out_of_scope_query(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        ctx = MockContext(llm=MockLLM(['{"score": 10}']))
        state = {"query": "What is the weather today?"}
        result = guardrail_node(state, context=ctx)
        assert result["is_in_scope"] is False
        assert result["routing_decision"] == "out_of_scope"

    def test_biomarkers_bypass(self):
        """Structured biomarker input should bypass the LLM scope check."""
        from src.services.agents.nodes.guardrail_node import guardrail_node

        ctx = MockContext(llm=MockLLM())
        state = {"query": "analyze", "biomarkers": {"Glucose": 185}}
        result = guardrail_node(state, context=ctx)
        assert result["is_in_scope"] is True
        assert result["guardrail_score"] == 95.0

    def test_llm_failure_defaults_in_scope(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        broken_llm = MagicMock()
        broken_llm.invoke.side_effect = Exception("LLM down")
        ctx = MockContext(llm=broken_llm)
        state = {"query": "What is HbA1c?"}
        result = guardrail_node(state, context=ctx)
        assert result["is_in_scope"] is True  # benefit of the doubt


# -----------------------------------------------------------------------
# Out-of-scope node
# -----------------------------------------------------------------------

class TestOutOfScopeNode:
    def test_returns_rejection(self):
        from src.services.agents.nodes.out_of_scope_node import out_of_scope_node
        from src.services.agents.prompts import OUT_OF_SCOPE_RESPONSE

        ctx = MockContext()
        result = out_of_scope_node({}, context=ctx)
        assert result["final_answer"] == OUT_OF_SCOPE_RESPONSE


# -----------------------------------------------------------------------
# Grade documents node
# -----------------------------------------------------------------------

class TestGradeDocumentsNode:
    def test_grades_relevant(self):
        from src.services.agents.nodes.grade_documents_node import grade_documents_node

        ctx = MockContext(llm=MockLLM(['{"relevant": true}', '{"relevant": false}']))
        state = {
            "query": "diabetes treatment",
            "retrieved_documents": [
                {"id": "1", "text": "Diabetes is treated with insulin."},
                {"id": "2", "text": "The weather is sunny today."},
            ],
        }
        result = grade_documents_node(state, context=ctx)
        assert len(result["relevant_documents"]) == 1
        assert result["grading_results"][0]["relevant"] is True
        assert result["grading_results"][1]["relevant"] is False

    def test_empty_docs_needs_rewrite(self):
        from src.services.agents.nodes.grade_documents_node import grade_documents_node

        ctx = MockContext()
        state = {"query": "test", "retrieved_documents": []}
        result = grade_documents_node(state, context=ctx)
        assert result["needs_rewrite"] is True


# -----------------------------------------------------------------------
# Rewrite query node
# -----------------------------------------------------------------------

class TestRewriteQueryNode:
    def test_rewrites(self):
        from src.services.agents.nodes.rewrite_query_node import rewrite_query_node

        ctx = MockContext(llm=MockLLM(["diabetes HbA1c glucose management guidelines"]))
        state = {"query": "sugar problems"}
        result = rewrite_query_node(state, context=ctx)
        # BUG FIX: the old assertion ended with `or result["rewritten_query"]`,
        # which is truthy for any non-empty string and made the test vacuous.
        # The mocked LLM always returns a rewrite containing "diabetes", so
        # pin that directly.
        assert "diabetes" in result["rewritten_query"].lower()

    def test_llm_failure_keeps_original(self):
        from src.services.agents.nodes.rewrite_query_node import rewrite_query_node

        broken_llm = MagicMock()
        broken_llm.invoke.side_effect = Exception("timeout")
        ctx = MockContext(llm=broken_llm)
        state = {"query": "original query"}
        result = rewrite_query_node(state, context=ctx)
        assert result["rewritten_query"] == "original query"


# -----------------------------------------------------------------------
# Generate answer node
# -----------------------------------------------------------------------

class TestGenerateAnswerNode:
    def test_generates_answer(self):
        from src.services.agents.nodes.generate_answer_node import generate_answer_node

        ctx = MockContext(llm=MockLLM(["Based on the evidence, HbA1c of 8.2% indicates poor glycemic control."]))
        state = {
            "query": "What does HbA1c 8.2 mean?",
            "relevant_documents": [
                {"title": "Diabetes Guide", "section": "Diagnosis", "text": "HbA1c above 6.5% indicates diabetes."}
            ],
        }
        result = generate_answer_node(state, context=ctx)
        assert "final_answer" in result
        assert len(result["final_answer"]) > 10

    def test_llm_failure_returns_fallback(self):
        from src.services.agents.nodes.generate_answer_node import generate_answer_node

        broken_llm = MagicMock()
        broken_llm.invoke.side_effect = Exception("dead")
        ctx = MockContext(llm=broken_llm)
        state = {"query": "test", "relevant_documents": []}
        result = generate_answer_node(state, context=ctx)
        assert "apologize" in result["final_answer"].lower()
        assert len(result["errors"]) > 0


# -----------------------------------------------------------------------
# Agentic RAG state
# -----------------------------------------------------------------------

class TestAgenticRAGState:
    def test_state_is_typed_dict(self):
        from src.services.agents.state import AgenticRAGState

        # Should be usable as a dict type hint
        state: AgenticRAGState = {
            "query": "test",
            "errors": [],
        }
        assert state["query"] == "test"
+""" + +import pytest + +from src.services.biomarker.service import BiomarkerService, ValidationReport + + +@pytest.fixture +def service(): + return BiomarkerService() + + +def test_validate_known_biomarkers(service: BiomarkerService): + """Should validate known biomarkers correctly.""" + report = service.validate({"Glucose": 185.0, "HbA1c": 8.2}) + assert isinstance(report, ValidationReport) + assert report.recognized_count >= 1 + # At least one result should exist + assert len(report.results) >= 1 + + +def test_validate_critical_generates_alert(service: BiomarkerService): + """Critically abnormal values should generate safety alerts.""" + # Glucose < 40 or > 500 should be critical + report = service.validate({"Glucose": 550.0}) + if report.recognized_count > 0: + critical = [r for r in report.results if r.status.startswith("CRITICAL")] + # If the validator flags it as critical, there should be alerts + if critical: + assert len(report.safety_alerts) > 0 + + +def test_validate_unrecognized(service: BiomarkerService): + """Unknown biomarker names should be listed as unrecognized.""" + report = service.validate({"FakeMarkerXYZ": 42.0}) + assert "FakeMarkerXYZ" in report.unrecognized + assert report.recognized_count == 0 + + +def test_list_supported(service: BiomarkerService): + """Should return a list of supported biomarkers.""" + supported = service.list_supported() + assert isinstance(supported, list) + # We know the validator has 24 biomarkers + assert len(supported) >= 20 diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..80078aee75f99efcd5042d647ff8fc2ce2a38559 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,27 @@ +""" +Tests for src/services/cache/redis_cache.py — graceful degradation. 
+""" + +import pytest + +from src.services.cache.redis_cache import RedisCache + + +class TestNullCache: + """When Redis is disabled, the NullCache should degrade gracefully.""" + + def test_null_cache_get_returns_none(self): + from src.services.cache.redis_cache import _NullCache + cache = _NullCache() + assert cache.get("anything") is None + + def test_null_cache_set_noop(self): + from src.services.cache.redis_cache import _NullCache + cache = _NullCache() + # Should not raise + cache.set("key", "value", ttl=10) + + def test_null_cache_delete_noop(self): + from src.services.cache.redis_cache import _NullCache + cache = _NullCache() + cache.delete("key") diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..3d09facf0c9c40ff71198557debe84703c214d01 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,43 @@ +""" +Tests for src/exceptions.py — domain exception hierarchy. +""" + +import pytest + +from src.exceptions import ( + AnalysisError, + BiomarkerError, + CacheError, + DatabaseError, + EmbeddingError, + GuardrailError, + LLMError, + MediGuardError, + ObservabilityError, + OllamaConnectionError, + OutOfScopeError, + PDFParsingError, + SearchError, + TelegramError, +) + + +def test_all_exceptions_inherit_from_root(): + """Every domain exception should inherit from MediGuardError.""" + for exc_cls in [ + DatabaseError, SearchError, EmbeddingError, PDFParsingError, + LLMError, OllamaConnectionError, BiomarkerError, AnalysisError, + GuardrailError, OutOfScopeError, CacheError, ObservabilityError, + TelegramError, + ]: + assert issubclass(exc_cls, MediGuardError), f"{exc_cls.__name__} must inherit MediGuardError" + + +def test_ollama_inherits_llm(): + assert issubclass(OllamaConnectionError, LLMError) + + +def test_exception_message(): + exc = SearchError("OpenSearch timeout") + assert str(exc) == "OpenSearch timeout" + assert isinstance(exc, MediGuardError) diff --git 
"""
Tests for src/models/analysis.py — SQLAlchemy ORM models.
"""

from src.models.analysis import MedicalDocument, PatientAnalysis, SOPVersion


def _column_names(model) -> set:
    """Return the set of column names mapped on *model*'s table."""
    return {column.name for column in model.__table__.columns}


def test_patient_analysis_tablename():
    assert PatientAnalysis.__tablename__ == "patient_analyses"


def test_medical_document_tablename():
    assert MedicalDocument.__tablename__ == "medical_documents"


def test_sop_version_tablename():
    assert SOPVersion.__tablename__ == "sop_versions"


def test_patient_analysis_has_columns():
    """PatientAnalysis must expose the core analysis columns."""
    required = {"id", "request_id", "biomarkers", "predicted_disease", "created_at"}
    assert required <= _column_names(PatientAnalysis)


def test_medical_document_has_columns():
    """MedicalDocument must expose the core document columns."""
    required = {"id", "title", "content_hash", "parse_status", "created_at"}
    assert required <= _column_names(MedicalDocument)
+""" + +from src.services.opensearch.index_config import MEDICAL_CHUNKS_MAPPING + + +def test_mapping_has_required_fields(): + """The mapping should define all required fields.""" + props = MEDICAL_CHUNKS_MAPPING["mappings"]["properties"] + required = ["chunk_text", "title", "embedding", "biomarkers_mentioned", "condition_tags"] + for field_name in required: + assert field_name in props, f"Missing field: {field_name}" + + +def test_knn_vector_config(): + """The embedding field should be configured for KNN.""" + embed = MEDICAL_CHUNKS_MAPPING["mappings"]["properties"]["embedding"] + assert embed["type"] == "knn_vector" + assert embed["dimension"] == 1024 + + +def test_synonym_analyzer(): + """Mapping should include a medical synonym analyzer.""" + analyzers = MEDICAL_CHUNKS_MAPPING["settings"]["analysis"]["analyzer"] + assert "medical_analyzer" in analyzers + + +def test_knn_enabled(): + """KNN should be enabled in settings.""" + settings = MEDICAL_CHUNKS_MAPPING["settings"] + assert settings["index"]["knn"] is True diff --git a/tests/test_pdf_parser.py b/tests/test_pdf_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..872b055a46d23be3394b8b981803d9f774172ff4 --- /dev/null +++ b/tests/test_pdf_parser.py @@ -0,0 +1,42 @@ +""" +Tests for src/services/pdf_parser/service.py — PDF parsing. 
+""" + +from pathlib import Path + +import pytest + +from src.services.pdf_parser.service import PDFParserService, ParsedDocument + + +@pytest.fixture +def parser(): + return PDFParserService() + + +def test_missing_file(parser: PDFParserService): + """Should return error for missing files.""" + result = parser.parse(Path("/nonexistent/fake.pdf")) + assert isinstance(result, ParsedDocument) + assert result.error is not None + assert "not found" in result.error.lower() + + +def test_parse_directory_empty(parser: PDFParserService, tmp_path: Path): + """Empty directory should return empty list.""" + results = parser.parse_directory(tmp_path) + assert results == [] + + +def test_parse_directory_with_pdf(parser: PDFParserService, tmp_path: Path): + """Should parse PDFs found in a directory.""" + # Check if there are any real PDFs in data/medical_pdfs + pdf_dir = Path("data/medical_pdfs") + if pdf_dir.exists() and list(pdf_dir.glob("*.pdf")): + results = parser.parse_directory(pdf_dir) + assert len(results) > 0 + for doc in results: + assert isinstance(doc, ParsedDocument) + assert doc.filename.endswith(".pdf") + else: + pytest.skip("No medical PDFs available for testing") diff --git a/tests/test_production_api.py b/tests/test_production_api.py new file mode 100644 index 0000000000000000000000000000000000000000..5dd8a70b892f2031e35e4abadd22f2d1526874f6 --- /dev/null +++ b/tests/test_production_api.py @@ -0,0 +1,85 @@ +""" +Tests for the production FastAPI app (src/main.py) — endpoint smoke tests. + +These tests use FastAPI's TestClient with mocked backing services +so they run without Docker infrastructure. 
+""" + +import pytest +from unittest.mock import MagicMock, patch + +from fastapi.testclient import TestClient + + +@pytest.fixture +def client(): + """Create a test client with mocked services.""" + # We need to prevent the lifespan from actually connecting to services + with patch("src.main.lifespan") as mock_lifespan: + # Use a no-op lifespan + from contextlib import asynccontextmanager + + @asynccontextmanager + async def _noop_lifespan(app): + import time + app.state.start_time = time.time() + app.state.version = "2.0.0-test" + app.state.opensearch_client = None + app.state.embedding_service = None + app.state.cache = None + app.state.ollama_client = None + app.state.tracer = None + app.state.rag_service = None + app.state.ragbot_service = None + yield + + mock_lifespan.side_effect = _noop_lifespan + + from src.main import create_app + app = create_app() + app.router.lifespan_context = _noop_lifespan + with TestClient(app) as tc: + yield tc + + +def test_root(client: TestClient): + resp = client.get("/") + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "MediGuard AI" + assert "endpoints" in data + + +def test_health(client: TestClient): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "healthy" + assert "version" in data + + +def test_ask_no_service(client: TestClient): + """Without RAG service, /ask should return 503.""" + resp = client.post("/ask", json={"question": "What is diabetes?"}) + assert resp.status_code == 503 + + +def test_search_no_service(client: TestClient): + """Without OpenSearch, /search should return 503.""" + resp = client.post("/search", json={"query": "diabetes", "top_k": 5}) + assert resp.status_code == 503 + + +def test_analyze_no_service(client: TestClient): + """Without RagBot, /analyze/structured should return 503.""" + resp = client.post( + "/analyze/structured", + json={"biomarkers": {"Glucose": 185.0}}, + ) + assert resp.status_code == 
def test_validation_error(client: TestClient):
    """A request failing schema validation should return 422."""
    resp = client.post("/ask", json={"question": "ab"})  # too short
    assert resp.status_code == 422


# ---------------------------------------------------------------------------
# tests/test_prompts.py — prompt templates
# ---------------------------------------------------------------------------

"""
Tests for src/services/agents/prompts.py — prompt templates.
"""

from src.services.agents.prompts import (
    GRADING_SYSTEM,
    GUARDRAIL_SYSTEM,
    OUT_OF_SCOPE_RESPONSE,
    RAG_GENERATION_SYSTEM,
    REWRITE_SYSTEM,
)


def test_guardrail_prompt_has_score():
    """Guardrail prompt should ask for a 0-100 score."""
    lowered = GUARDRAIL_SYSTEM.lower()
    assert "score" in lowered
    assert "0" in GUARDRAIL_SYSTEM
    assert "100" in GUARDRAIL_SYSTEM


def test_grading_prompt_has_relevant():
    """Grading prompt should ask for a relevant true/false verdict."""
    assert "relevant" in GRADING_SYSTEM.lower()


def test_rag_generation_has_citation():
    """RAG generation prompt should mention citations in some form."""
    lowered = RAG_GENERATION_SYSTEM.lower()
    assert "citation" in lowered or "cite" in lowered


def test_out_of_scope_is_polite():
    """Out-of-scope response should be informative and polite."""
    assert "medical" in OUT_OF_SCOPE_RESPONSE.lower()
    assert len(OUT_OF_SCOPE_RESPONSE) > 50


def test_rewrite_prompt_exists():
    """Rewrite prompt should be substantive, not a stub."""
    assert len(REWRITE_SYSTEM) > 50
+""" + +import pytest +from pydantic import ValidationError + +from src.schemas.schemas import ( + AskRequest, + AskResponse, + HealthResponse, + NaturalAnalysisRequest, + SearchRequest, + SearchResponse, + StructuredAnalysisRequest, +) + + +class TestNaturalAnalysisRequest: + def test_valid(self): + req = NaturalAnalysisRequest(message="My glucose is 185 and HbA1c 8.2") + assert req.message == "My glucose is 185 and HbA1c 8.2" + + def test_too_short(self): + with pytest.raises(ValidationError): + NaturalAnalysisRequest(message="hi") + + +class TestStructuredAnalysisRequest: + def test_valid(self): + req = StructuredAnalysisRequest(biomarkers={"Glucose": 185.0}) + assert req.biomarkers["Glucose"] == 185.0 + + def test_empty_biomarkers(self): + with pytest.raises(ValidationError): + StructuredAnalysisRequest(biomarkers={}) + + +class TestAskRequest: + def test_valid(self): + req = AskRequest(question="What does high HbA1c mean?") + assert "HbA1c" in req.question + + def test_too_short(self): + with pytest.raises(ValidationError): + AskRequest(question="ab") + + def test_with_biomarkers(self): + req = AskRequest( + question="Explain my results", + biomarkers={"Glucose": 200.0}, + patient_context="52-year-old male", + ) + assert req.biomarkers is not None + + +class TestSearchRequest: + def test_defaults(self): + req = SearchRequest(query="diabetes guidelines") + assert req.top_k == 10 + assert req.mode == "hybrid" + + +class TestAskResponse: + def test_round_trip(self): + resp = AskResponse( + request_id="req_abc", + question="test?", + answer="test answer", + documents_retrieved=5, + documents_relevant=3, + processing_time_ms=123.4, + ) + data = resp.model_dump() + assert data["status"] == "success" + assert data["documents_relevant"] == 3 + + +class TestHealthResponse: + def test_basic(self): + resp = HealthResponse( + status="healthy", + timestamp="2025-01-01T00:00:00Z", + version="2.0.0", + uptime_seconds=42.0, + ) + assert resp.status == "healthy" diff --git 
"""
Tests for src/settings.py — Pydantic Settings hierarchy.

NOTE: unused ``import os`` and ``unittest.mock.patch`` imports were removed;
no test in this module touched either name.
"""

import pytest


def test_settings_defaults():
    """Settings should have sensible defaults without env vars."""
    from src.settings import get_settings

    # Clear any cached instance so defaults are re-evaluated.
    get_settings.cache_clear()

    settings = get_settings()
    assert settings.api.port == 8000
    assert "mediguard" in settings.postgres.database_url
    assert "localhost" in settings.opensearch.host
    assert settings.redis.port == 6379
    assert settings.ollama.model == "llama3.1:8b"
    assert settings.embedding.dimension == 1024
    assert settings.chunking.chunk_size == 600


def test_settings_frozen():
    """Settings should be immutable (frozen models reject assignment)."""
    from src.settings import get_settings

    get_settings.cache_clear()

    settings = get_settings()
    with pytest.raises(Exception):
        settings.api.port = 9999


def test_settings_singleton():
    """get_settings should return the same cached instance."""
    from src.settings import get_settings

    get_settings.cache_clear()

    first = get_settings()
    second = get_settings()
    assert first is second
+""" + +import pytest + +from src.services.indexing.text_chunker import MedicalChunk, MedicalTextChunker + + +@pytest.fixture +def chunker(): + return MedicalTextChunker(target_words=30, overlap_words=5, min_words=5) + + +def test_basic_chunking(chunker: MedicalTextChunker): + """Should split text into chunks.""" + # Generate enough words to require multiple chunks (target_words=30) + words = [f"word{i}" for i in range(200)] + text = " ".join(words) + chunks = chunker.chunk_text(text) + assert len(chunks) > 1 + for c in chunks: + assert isinstance(c, MedicalChunk) + assert c.text.strip() + + +def test_section_aware(chunker: MedicalTextChunker): + """Should detect section headers.""" + text = ( + "Introduction\nThis study examines diabetes.\n\n" + "Methods\nWe collected blood samples.\n\n" + "Results\nGlucose levels were elevated." + ) + chunks = chunker.chunk_text(text) + assert len(chunks) >= 1 + + +def test_biomarker_detection(chunker: MedicalTextChunker): + """Should detect biomarkers in chunks.""" + text = ( + "The patient's HbA1c was 8.2% indicating poor glycemic control. " + "Fasting glucose was 185 mg/dL and total cholesterol was elevated at 240." + ) + chunks = chunker.chunk_text(text) + assert len(chunks) >= 1 + # At least one chunk should have biomarkers detected + all_biomarkers = set() + for c in chunks: + all_biomarkers.update(c.biomarkers_mentioned) + assert len(all_biomarkers) > 0 + + +def test_condition_tagging(chunker: MedicalTextChunker): + """Should tag chunks with relevant conditions.""" + text = ( + "Diabetes mellitus is characterised by insulin resistance and elevated blood glucose. " + "Cardiovascular disease risk increases with uncontrolled hypertension." 
+ ) + chunks = chunker.chunk_text(text) + all_tags = set() + for c in chunks: + all_tags.update(c.condition_tags) + assert "diabetes" in all_tags or "heart_disease" in all_tags + + +def test_empty_text(chunker: MedicalTextChunker): + """Empty text should return empty list.""" + assert chunker.chunk_text("") == [] + assert chunker.chunk_text(" ") == []