diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..44487f1e8406014b871e592245fab3aebb0ac300 --- /dev/null +++ b/.env.example @@ -0,0 +1,61 @@ +# =========================================================================== +# MediGuard AI β€” Environment Variables +# =========================================================================== +# Copy this file to .env and fill in your values. +# =========================================================================== + +# --- API --- +API__HOST=0.0.0.0 +API__PORT=8000 +API__DEBUG=true +CORS_ALLOWED_ORIGINS=* + +# --- PostgreSQL --- +POSTGRES__HOST=localhost +POSTGRES__PORT=5432 +POSTGRES__DATABASE=mediguard +POSTGRES__USER=mediguard +POSTGRES__PASSWORD=mediguard_secret + +# --- OpenSearch --- +OPENSEARCH__HOST=localhost +OPENSEARCH__PORT=9200 + +# --- Redis --- +REDIS__HOST=localhost +REDIS__PORT=6379 +REDIS__ENABLED=true + +# --- Ollama --- +OLLAMA__BASE_URL=http://localhost:11434 +OLLAMA__MODEL=llama3.2 + +# --- LLM (Groq / Gemini β€” existing providers) --- +LLM__PRIMARY_PROVIDER=groq +LLM__GROQ_API_KEY= +LLM__GROQ_MODEL=llama-3.3-70b-versatile +LLM__GEMINI_API_KEY= +LLM__GEMINI_MODEL=gemini-2.0-flash + +# --- Embeddings --- +EMBEDDING__PROVIDER=jina +EMBEDDING__JINA_API_KEY= +EMBEDDING__MODEL_NAME=jina-embeddings-v3 +EMBEDDING__DIMENSION=1024 + +# --- Langfuse --- +LANGFUSE__ENABLED=true +LANGFUSE__PUBLIC_KEY= +LANGFUSE__SECRET_KEY= +LANGFUSE__HOST=http://localhost:3000 + +# --- Chunking --- +CHUNKING__CHUNK_SIZE=1024 +CHUNKING__CHUNK_OVERLAP=128 + +# --- Telegram Bot (optional) --- +TELEGRAM__BOT_TOKEN= +TELEGRAM__API_BASE_URL=http://localhost:8000 + +# --- Medical PDFs --- +MEDICAL_PDFS__DIRECTORY=data/medical_pdfs diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a39f7f37b21cea5538f5ce50ec34ca19275eb7ea --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.faiss filter=lfs diff=lfs 
merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index aab7cfeb3390df42c87003444dcf621f64cab6ac..b3651429b10b39c109c3f035958645f20a3df6bb 100644 --- a/.gitignore +++ b/.gitignore @@ -221,10 +221,13 @@ $RECYCLE.BIN/ # Project Specific # ============================================================================== # Vector stores (large files, regenerate locally) +# BUT allow medical_knowledge for HuggingFace deployment data/vector_stores/*.faiss data/vector_stores/*.pkl -*.faiss -*.pkl +!data/vector_stores/medical_knowledge.faiss +!data/vector_stores/medical_knowledge.pkl +# *.faiss # Commented out to allow medical_knowledge +# *.pkl # Commented out to allow medical_knowledge # Medical PDFs (proprietary/large) data/medical_pdfs/*.pdf diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9a2e852147524cfe46cfa5895d9abf8686701822 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +# MediGuard AI β€” Pre-commit hooks +# Install: pre-commit install +# Run all: pre-commit run --all-files + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-json + - id: check-merge-conflict + - id: detect-private-key + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.0 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.12.0 + hooks: + - id: mypy + additional_dependencies: [pydantic>=2.0] + args: [--ignore-missing-imports] diff --git a/DEPLOY_HUGGINGFACE.md b/DEPLOY_HUGGINGFACE.md new file mode 100644 index 0000000000000000000000000000000000000000..6dcd3354602c1221459ac900b6c1da6f1242d545 --- /dev/null +++ b/DEPLOY_HUGGINGFACE.md @@ -0,0 +1,203 @@ +# πŸš€ Deploy MediGuard AI to Hugging Face Spaces + +This guide walks you through 
deploying MediGuard AI to Hugging Face Spaces using Docker. + +## Prerequisites + +1. **Hugging Face Account** β€” [Sign up free](https://huggingface.co/join) +2. **Git** β€” Installed on your machine +3. **API Key** β€” Either: + - **Groq** (recommended) β€” [Get free key](https://console.groq.com/keys) + - **Google Gemini** β€” [Get free key](https://aistudio.google.com/app/apikey) + +## Step 1: Create a New Space + +1. Go to [huggingface.co/new-space](https://huggingface.co/new-space) +2. Fill in: + - **Space name**: `mediguard-ai` (or your choice) + - **License**: MIT + - **SDK**: Select **Docker** + - **Hardware**: **CPU Basic** (free tier works!) +3. Click **Create Space** + +## Step 2: Clone Your Space + +```bash +# Clone the empty space +git clone https://huggingface.co/spaces/YOUR_USERNAME/mediguard-ai +cd mediguard-ai +``` + +## Step 3: Copy Project Files + +Copy all files from this repository to your space folder: + +```bash +# Option A: If you have the RagBot repo locally +cp -r /path/to/RagBot/* . + +# Option B: Clone fresh +git clone https://github.com/yourusername/ragbot temp +cp -r temp/* . +rm -rf temp +``` + +## Step 4: Set Up Dockerfile for Spaces + +Hugging Face Spaces expects the Dockerfile in the root. Copy the HF-optimized Dockerfile: + +```bash +# Copy the HF Spaces Dockerfile to root +cp huggingface/Dockerfile ./Dockerfile +``` + +**Or** update your root `Dockerfile` to match the HF Spaces version. + +## Step 5: Set Up README (Important!) + +The README.md must have the HF Spaces metadata header. Copy the HF README: + +```bash +# Backup original README +mv README.md README_original.md + +# Use HF Spaces README +cp huggingface/README.md ./README.md +``` + +## Step 6: Add Your API Key (Secret) + +1. Go to your Space: `https://huggingface.co/spaces/YOUR_USERNAME/mediguard-ai` +2. Click **Settings** tab +3. Scroll to **Repository Secrets** +4. Add a new secret: + - **Name**: `GROQ_API_KEY` (or `GOOGLE_API_KEY`) + - **Value**: Your API key +5. 
Click **Add** + +## Step 7: Push to Deploy + +```bash +# Add all files +git add . + +# Commit +git commit -m "Deploy MediGuard AI" + +# Push to Hugging Face +git push +``` + +## Step 8: Monitor Deployment + +1. Go to your Space: `https://huggingface.co/spaces/YOUR_USERNAME/mediguard-ai` +2. Click the **Logs** tab to watch the build +3. Build takes ~5-10 minutes (first time) +4. Once "Running", your app is live! πŸŽ‰ + +## πŸ”§ Troubleshooting + +### "No LLM API key configured" + +- Make sure you added `GROQ_API_KEY` or `GOOGLE_API_KEY` in Space Settings β†’ Secrets +- Secret names are case-sensitive + +### Build fails with "No space disk" + +- Hugging Face free tier has limited disk space +- The FAISS vector store might be too large +- Solution: Upgrade to a paid tier or reduce vector store size + +### "ModuleNotFoundError" + +- Check that all dependencies are in `huggingface/requirements.txt` +- The Dockerfile should install from this file + +### App crashes on startup + +- Check Logs for the actual error +- Common issue: Missing environment variables +- Increase Space hardware if OOM error + +## πŸ“ File Structure for Deployment + +Your Space should have this structure: + +``` +your-space/ +β”œβ”€β”€ Dockerfile # HF Spaces Dockerfile (from huggingface/) +β”œβ”€β”€ README.md # HF Spaces README with metadata +β”œβ”€β”€ huggingface/ +β”‚ β”œβ”€β”€ app.py # Standalone Gradio app +β”‚ β”œβ”€β”€ requirements.txt # Minimal deps for HF +β”‚ └── README.md # Original HF README +β”œβ”€β”€ src/ # Core application code +β”‚ β”œβ”€β”€ workflow.py +β”‚ β”œβ”€β”€ state.py +β”‚ β”œβ”€β”€ llm_config.py +β”‚ β”œβ”€β”€ pdf_processor.py +β”‚ β”œβ”€β”€ agents/ +β”‚ └── ... +β”œβ”€β”€ data/ +β”‚ └── vector_stores/ +β”‚ β”œβ”€β”€ medical_knowledge.faiss +β”‚ └── medical_knowledge.pkl +└── config/ + └── biomarker_references.json +``` + +## πŸ”„ Updating Your Space + +To update after making changes: + +```bash +git add . 
+git commit -m "Update: description of changes" +git push +``` + +Hugging Face will automatically rebuild and redeploy. + +## 💰 Hardware Options + +| Tier | RAM | vCPU | Cost | Best For | +|------|-----|------|------|----------| +| CPU Basic | 2GB | 2 | Free | Demo/Testing | +| CPU Upgrade | 8GB | 4 | ~$0.03/hr | Production | +| T4 Small | 16GB | 4 | ~$0.06/hr | Heavy usage | + +The free tier works for demos. Upgrade if you experience timeouts. + +## 🎉 Your Space is Live! + +Once deployed, share your Space URL: + +``` +https://huggingface.co/spaces/YOUR_USERNAME/mediguard-ai +``` + +Anyone can now use MediGuard AI without any setup! + +--- + +## Quick Commands Reference + +```bash +# Clone your space +git clone https://huggingface.co/spaces/YOUR_USERNAME/mediguard-ai + +# Set up remote (if needed) +git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/mediguard-ai + +# Push changes +git push origin main + +# Force rebuild (if stuck) +# Go to Settings → Factory Reset +``` + +## Need Help? + +- [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces) +- [Docker on Spaces](https://huggingface.co/docs/hub/spaces-sdks-docker) +- [Spaces Secrets](https://huggingface.co/docs/hub/spaces-secrets) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..d7f3986a1b5f282a9519d7999cc2bcbb132aa898 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,66 @@ +# =========================================================================== +# MediGuard AI — Hugging Face Spaces Dockerfile +# =========================================================================== +# Optimized single-container deployment for Hugging Face Spaces. +# Uses FAISS vector store + Cloud LLMs (Groq/Gemini) - no external services. 
+# =========================================================================== + +FROM python:3.11-slim + +# Non-interactive apt +ENV DEBIAN_FRONTEND=noninteractive + +# Python settings +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +# HuggingFace Spaces runs on port 7860 +ENV GRADIO_SERVER_NAME="0.0.0.0" \ + GRADIO_SERVER_PORT=7860 + +# Default to HuggingFace embeddings (local, no API key needed) +ENV EMBEDDING_PROVIDER=huggingface + +WORKDIR /app + +# System dependencies +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first (cache layer) +COPY huggingface/requirements.txt ./requirements.txt +RUN pip install --upgrade pip && \ + pip install -r requirements.txt + +# Copy the entire project +COPY . . + +# Create necessary directories and ensure vector store exists +RUN mkdir -p data/medical_pdfs data/vector_stores data/chat_reports + +# Create non-root user (HF Spaces requirement) +RUN useradd -m -u 1000 user + +# Make app writable by user +RUN chown -R user:user /app + +USER user +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +WORKDIR /app + +EXPOSE 7860 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --retries=3 \ + CMD curl -sf http://localhost:7860/ || exit 1 + +# Launch Gradio app +CMD ["python", "huggingface/app.py"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..631f2ac0d6e59671117bf5ca337034456bc35977 --- /dev/null +++ b/Makefile @@ -0,0 +1,137 @@ +# =========================================================================== +# MediGuard AI β€” Makefile +# =========================================================================== +# Usage: +# make help β€” show all targets +# make setup β€” install deps + pre-commit hooks +# make dev β€” run API in dev mode with reload +# make test β€” run full 
test suite +# make lint β€” ruff check + mypy +# make docker-up β€” spin up all Docker services +# make docker-down β€” tear down Docker services +# =========================================================================== + +.DEFAULT_GOAL := help +SHELL := /bin/bash + +# Python / UV +PYTHON ?= python +UV ?= uv +PIP ?= pip + +# Docker +COMPOSE := docker compose + +# --------------------------------------------------------------------------- +# Help +# --------------------------------------------------------------------------- +.PHONY: help +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- +.PHONY: setup +setup: ## Install all deps (pip) + pre-commit hooks + $(PIP) install -e ".[all]" + pre-commit install + +.PHONY: setup-uv +setup-uv: ## Install all deps with UV + $(UV) pip install -e ".[all]" + pre-commit install + +# --------------------------------------------------------------------------- +# Development +# --------------------------------------------------------------------------- +.PHONY: dev +dev: ## Run API in dev mode (auto-reload) + uvicorn src.main:app --host 0.0.0.0 --port 8000 --reload + +.PHONY: gradio +gradio: ## Launch Gradio web UI + $(PYTHON) -m src.gradio_app + +.PHONY: telegram +telegram: ## Start Telegram bot + $(PYTHON) -c "from src.services.telegram.bot import MediGuardTelegramBot; MediGuardTelegramBot().run()" + +# --------------------------------------------------------------------------- +# Quality +# --------------------------------------------------------------------------- +.PHONY: lint +lint: ## Ruff check + MyPy + ruff check src/ tests/ + mypy src/ --ignore-missing-imports + +.PHONY: format +format: ## Ruff format + ruff format src/ tests/ + ruff check --fix 
src/ tests/ + +.PHONY: test +test: ## Run pytest with coverage + pytest tests/ -v --tb=short --cov=src --cov-report=term-missing + +.PHONY: test-quick +test-quick: ## Run only fast unit tests + pytest tests/ -v --tb=short -m "not slow" + +# --------------------------------------------------------------------------- +# Docker +# --------------------------------------------------------------------------- +.PHONY: docker-up +docker-up: ## Start all Docker services (detached) + $(COMPOSE) up -d + +.PHONY: docker-down +docker-down: ## Stop and remove Docker services + $(COMPOSE) down -v + +.PHONY: docker-build +docker-build: ## Build Docker images + $(COMPOSE) build + +.PHONY: docker-logs +docker-logs: ## Tail Docker logs + $(COMPOSE) logs -f + +# --------------------------------------------------------------------------- +# Database +# --------------------------------------------------------------------------- +.PHONY: db-upgrade +db-upgrade: ## Run Alembic migrations + alembic upgrade head + +.PHONY: db-revision +db-revision: ## Create a new Alembic migration + alembic revision --autogenerate -m "$(msg)" + +# --------------------------------------------------------------------------- +# Indexing +# --------------------------------------------------------------------------- +.PHONY: index-pdfs +index-pdfs: ## Parse and index all medical PDFs + $(PYTHON) -c "\ +from pathlib import Path; \ +from src.services.pdf_parser.service import make_pdf_parser_service; \ +from src.services.indexing.service import IndexingService; \ +from src.services.embeddings.service import make_embedding_service; \ +from src.services.opensearch.client import make_opensearch_client; \ +parser = make_pdf_parser_service(); \ +idx = IndexingService(make_embedding_service(), make_opensearch_client()); \ +docs = parser.parse_directory(Path('data/medical_pdfs')); \ +[idx.index_text(d.full_text, {'title': d.filename}) for d in docs if d.full_text]; \ +print(f'Indexed {len(docs)} documents')" + +# 
--------------------------------------------------------------------------- +# Clean +# --------------------------------------------------------------------------- +.PHONY: clean +clean: ## Remove build artifacts and caches + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type d -name .pytest_cache -exec rm -rf {} + 2>/dev/null || true + find . -type d -name .mypy_cache -exec rm -rf {} + 2>/dev/null || true + find . -type d -name .ruff_cache -exec rm -rf {} + 2>/dev/null || true + rm -rf dist/ build/ *.egg-info diff --git a/README.md b/README.md index 8afa76196b53c4047d8bc1dac3699c8aa2242a50..4c32163f3524e530799e34e0895039182010eeec 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,22 @@ +--- +title: Agentic RagBot +emoji: πŸ₯ +colorFrom: blue +colorTo: indigo +sdk: docker +pinned: true +license: mit +app_port: 7860 +tags: + - medical + - biomarker + - rag + - healthcare + - langgraph + - agents +short_description: Multi-Agent RAG System for Medical Biomarker Analysis +--- + # RagBot: Multi-Agent RAG System for Medical Biomarker Analysis A production-ready biomarker analysis system combining 6 specialized AI agents with medical knowledge retrieval to provide evidence-based insights on blood test results in **15-25 seconds**. diff --git a/airflow/dags/ingest_pdfs.py b/airflow/dags/ingest_pdfs.py new file mode 100644 index 0000000000000000000000000000000000000000..07c9fc9f19c743de4233a583e4c61f0a28bf5d7d --- /dev/null +++ b/airflow/dags/ingest_pdfs.py @@ -0,0 +1,64 @@ +""" +MediGuard AI β€” Airflow DAG: Ingest Medical PDFs + +Periodically scans the medical_pdfs directory, parses new PDFs, +chunks them, generates embeddings, and indexes into OpenSearch. 
def _ingest_pdfs(**kwargs):
    """Parse the configured PDF directory and index each usable document.

    The directory is read from ``get_settings().medical_pdfs.directory``.
    A document is indexed only when it has non-empty ``full_text`` and no
    ``error``; its filename is stored as the ``title`` metadata field.

    Args:
        **kwargs: Accepted for the operator's calling convention; not used
            by this function.

    Returns:
        A dict with ``total`` (number of parsed documents) and ``indexed``
        (number actually sent to the indexing service).
    """
    # Project imports are resolved at call time rather than module import
    # time -- presumably so DAG parsing does not need these packages
    # (TODO confirm).
    from pathlib import Path

    from src.services.embeddings.service import make_embedding_service
    from src.services.indexing.service import IndexingService
    from src.services.opensearch.client import make_opensearch_client
    from src.services.pdf_parser.service import make_pdf_parser_service
    from src.settings import get_settings

    source_dir = Path(get_settings().medical_pdfs.directory)

    pdf_parser = make_pdf_parser_service()
    indexer = IndexingService(make_embedding_service(), make_opensearch_client())

    parsed = pdf_parser.parse_directory(source_dir)
    indexed_count = 0
    for document in parsed:
        # Skip documents with no extracted text or a recorded parse error.
        if not document.full_text or document.error:
            continue
        indexer.index_text(document.full_text, {"title": document.filename})
        indexed_count += 1

    total_count = len(parsed)
    print(f"Ingested {indexed_count}/{total_count} documents")
    return {"total": total_count, "indexed": indexed_count}
def _run_evolution(**kwargs):
    """Run a single SOP evolution cycle and hand its result back to the caller.

    Args:
        **kwargs: Accepted for the operator's calling convention; not used
            by this function.

    Returns:
        Whatever ``run_evolution_cycle()`` returns, unchanged.
    """
    # Resolved at call time, not at module import time.
    from src.evolution.director import run_evolution_cycle

    outcome = run_evolution_cycle()
    print(f"Evolution cycle complete: {outcome}")
    return outcome
forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s +# Or organize into date-based subdirectories (requires recursive_version_locations = true) +# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. 
+# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000000000000000000000000000000000000..98e4f9c44effe479ed38c66ba922e7bcc672916f --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. 
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The migration context is configured from the ``sqlalchemy.url`` option
    alone, without building an Engine, so no DBAPI has to be available.
    Calls to ``context.execute()`` emit the given SQL string to the script
    output instead of running it against a database.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        # Render bound parameters inline in the emitted SQL.
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    An Engine is built from the ``sqlalchemy.``-prefixed options of the
    active ini section, and a live connection from it is associated with
    the migration context before the migrations run.
    """
    ini_options = config.get_section(config.config_ini_section, {})
    engine = engine_from_config(
        ini_options,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with engine.connect() as conn:
        context.configure(connection=conn, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()
+revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/data/vector_stores/medical_knowledge.faiss b/data/vector_stores/medical_knowledge.faiss new file mode 100644 index 0000000000000000000000000000000000000000..c59312e0a3defe88767fce7db4a7e4f0b9692f7f --- /dev/null +++ b/data/vector_stores/medical_knowledge.faiss @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9dee84846c00eda0f0a5487b61c2dd9cc85588ee0cbbcb576df24e8881969e1 +size 4007469 diff --git a/data/vector_stores/medical_knowledge.pkl b/data/vector_stores/medical_knowledge.pkl new file mode 100644 index 0000000000000000000000000000000000000000..26ee39caecefed3f6eefb346ca8a99fcdfd611de --- /dev/null +++ b/data/vector_stores/medical_knowledge.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:690fa693a48c3eb5e0a1fc11b7008a9037630928d9c8a634a31e7f90d8e2f7fb +size 2727206 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..aac9873bc7d779eda6ef82fc5e8991ab45c79d68 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,168 @@ +# =========================================================================== +# MediGuard AI β€” Docker Compose (development / CI) +# =========================================================================== +# Usage: +# docker compose up -d β€” start all services +# docker compose down -v β€” stop and remove volumes +# docker compose logs -f api β€” follow API logs +# =========================================================================== + +services: + # 
----------------------------------------------------------------------- + # Application + # ----------------------------------------------------------------------- + api: + build: + context: . + dockerfile: Dockerfile + target: production + container_name: mediguard-api + ports: + - "${API_PORT:-8000}:8000" + env_file: .env + environment: + - POSTGRES__HOST=postgres + - OPENSEARCH__HOST=opensearch + - OPENSEARCH__PORT=9200 + - REDIS__HOST=redis + - REDIS__PORT=6379 + - OLLAMA__BASE_URL=http://ollama:11434 + - LANGFUSE__HOST=http://langfuse:3000 + depends_on: + postgres: + condition: service_healthy + opensearch: + condition: service_healthy + redis: + condition: service_healthy + volumes: + - ./data/medical_pdfs:/app/data/medical_pdfs:ro + restart: unless-stopped + + gradio: + build: + context: . + dockerfile: Dockerfile + target: production + container_name: mediguard-gradio + command: python -m src.gradio_app + ports: + - "${GRADIO_PORT:-7860}:7860" + environment: + - MEDIGUARD_API_URL=http://api:8000 + depends_on: + - api + restart: unless-stopped + + # ----------------------------------------------------------------------- + # Backing services + # ----------------------------------------------------------------------- + postgres: + image: postgres:16-alpine + container_name: mediguard-postgres + environment: + POSTGRES_DB: ${POSTGRES__DATABASE:-mediguard} + POSTGRES_USER: ${POSTGRES__USER:-mediguard} + POSTGRES_PASSWORD: ${POSTGRES__PASSWORD:-mediguard_secret} + ports: + - "${POSTGRES_PORT:-5432}:5432" + volumes: + - pg_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U mediguard"] + interval: 5s + timeout: 3s + retries: 10 + restart: unless-stopped + + opensearch: + image: opensearchproject/opensearch:2.11.1 + container_name: mediguard-opensearch + environment: + - discovery.type=single-node + - DISABLE_SECURITY_PLUGIN=true + - plugins.security.disabled=true + - "OPENSEARCH_JAVA_OPTS=-Xms256m -Xmx256m" + - 
bootstrap.memory_lock=true + ulimits: + memlock: { soft: -1, hard: -1 } + nofile: { soft: 65536, hard: 65536 } + ports: + - "${OPENSEARCH_PORT:-9200}:9200" + volumes: + - os_data:/usr/share/opensearch/data + healthcheck: + test: ["CMD-SHELL", "curl -sf http://localhost:9200/_cluster/health || exit 1"] + interval: 10s + timeout: 5s + retries: 24 + restart: unless-stopped + + # opensearch-dashboards: disabled by default — uncomment if you need the UI + # opensearch-dashboards: + # image: opensearchproject/opensearch-dashboards:2.11.1 + # container_name: mediguard-os-dashboards + # environment: + # - OPENSEARCH_HOSTS=["http://opensearch:9200"] + # - DISABLE_SECURITY_DASHBOARDS_PLUGIN=true + # ports: + # - "${OS_DASHBOARDS_PORT:-5601}:5601" + # depends_on: + # opensearch: + # condition: service_healthy + # restart: unless-stopped + + redis: + image: redis:7-alpine + container_name: mediguard-redis + ports: + - "${REDIS_PORT:-6379}:6379" + volumes: + - redis_data:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 3s + retries: 10 + restart: unless-stopped + + ollama: + image: ollama/ollama:latest + container_name: mediguard-ollama + ports: + - "${OLLAMA_PORT:-11434}:11434" + volumes: + - ollama_data:/root/.ollama + restart: unless-stopped + # Uncomment for GPU support: + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] + + # ----------------------------------------------------------------------- + # Observability + # ----------------------------------------------------------------------- + langfuse: + image: langfuse/langfuse:2 + container_name: mediguard-langfuse + environment: + - DATABASE_URL=postgresql://mediguard:mediguard_secret@postgres:5432/langfuse + - NEXTAUTH_URL=http://localhost:3000 + - NEXTAUTH_SECRET=mediguard-langfuse-secret-change-me + - SALT=mediguard-langfuse-salt-change-me + ports: + - "${LANGFUSE_PORT:-3000}:3000" + depends_on: + postgres: + 
condition: service_healthy + restart: unless-stopped + +volumes: + pg_data: + os_data: + redis_data: + ollama_data: diff --git a/huggingface/.env.example b/huggingface/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..a08ada8cfa775ff5afd7ff00671654708ba7f9f3 --- /dev/null +++ b/huggingface/.env.example @@ -0,0 +1,21 @@ +# =========================================================================== +# MediGuard AI — HuggingFace Spaces Environment Variables +# =========================================================================== +# MINIMAL config for HuggingFace Spaces deployment. +# Only the LLM API key is required — everything else has sensible defaults. +# =========================================================================== + +# --- LLM Provider (choose ONE) --- +# Option 1: Groq (RECOMMENDED - fast, free) +GROQ_API_KEY=your_groq_api_key_here + +# Option 2: Google Gemini (alternative free option) +# GOOGLE_API_KEY=your_google_api_key_here + +# --- Provider Selection (auto-detected from keys) --- +LLM_PROVIDER=groq + +# --- Embedding Provider (must match vector store) --- +# The bundled vector store uses HuggingFace embeddings (384 dim) +# DO NOT CHANGE THIS unless you rebuild the vector store! +EMBEDDING_PROVIDER=huggingface diff --git a/huggingface/Dockerfile b/huggingface/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..d7f3986a1b5f282a9519d7999cc2bcbb132aa898 --- /dev/null +++ b/huggingface/Dockerfile @@ -0,0 +1,66 @@ +# =========================================================================== +# MediGuard AI — Hugging Face Spaces Dockerfile +# =========================================================================== +# Optimized single-container deployment for Hugging Face Spaces. +# Uses FAISS vector store + Cloud LLMs (Groq/Gemini) - no external services. 
+# =========================================================================== + +FROM python:3.11-slim + +# Non-interactive apt +ENV DEBIAN_FRONTEND=noninteractive + +# Python settings +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +# HuggingFace Spaces runs on port 7860 +ENV GRADIO_SERVER_NAME="0.0.0.0" \ + GRADIO_SERVER_PORT=7860 + +# Default to HuggingFace embeddings (local, no API key needed) +ENV EMBEDDING_PROVIDER=huggingface + +WORKDIR /app + +# System dependencies +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first (cache layer) +COPY huggingface/requirements.txt ./requirements.txt +RUN pip install --upgrade pip && \ + pip install -r requirements.txt + +# Copy the entire project +COPY . . + +# Create necessary directories and ensure the vector store directory exists +RUN mkdir -p data/medical_pdfs data/vector_stores data/chat_reports + +# Create non-root user (HF Spaces requirement) +RUN useradd -m -u 1000 user + +# Make app writable by user +RUN chown -R user:user /app + +USER user +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +WORKDIR /app + +EXPOSE 7860 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --retries=3 \ + CMD curl -sf http://localhost:7860/ || exit 1 + +# Launch Gradio app +CMD ["python", "huggingface/app.py"] diff --git a/huggingface/README.md b/huggingface/README.md new file mode 100644 index 0000000000000000000000000000000000000000..90759d99c03fbcf020b15deacc6e2922b3d61780 --- /dev/null +++ b/huggingface/README.md @@ -0,0 +1,109 @@ +--- +title: Agentic RagBot +emoji: 🏥 +colorFrom: blue +colorTo: indigo +sdk: docker +pinned: true +license: mit +app_port: 7860 +tags: + - medical + - biomarker + - rag + - healthcare + - langgraph + - agents +short_description: Multi-Agent RAG System for Medical Biomarker Analysis +--- + +# 🏥 MediGuard AI — 
Medical Biomarker Analysis + +A production-ready **Multi-Agent RAG System** that analyzes blood test biomarkers using 6 specialized AI agents with medical knowledge retrieval. + +## ✨ Features + +- **6 Specialist AI Agents** β€” Biomarker validation, disease prediction, RAG-powered analysis, confidence assessment +- **Medical Knowledge Base** β€” 750+ pages of clinical guidelines (FAISS vector store) +- **Evidence-Based** β€” All recommendations backed by retrieved medical literature +- **Free Cloud LLMs** β€” Uses Groq (LLaMA 3.3-70B) or Google Gemini + +## πŸš€ Quick Start + +1. **Enter your biomarkers** in any format: + - `Glucose: 140, HbA1c: 7.5` + - `My glucose is 140 and HbA1c is 7.5` + - `{"Glucose": 140, "HbA1c": 7.5}` + +2. **Click Analyze** and get: + - Primary diagnosis with confidence score + - Critical alerts and safety flags + - Biomarker analysis with normal ranges + - Evidence-based recommendations + - Disease pathophysiology explanation + +## πŸ”§ Configuration + +This Space requires an LLM API key. Add one of these secrets in Space Settings: + +| Secret | Provider | Get Free Key | +|--------|----------|--------------| +| `GROQ_API_KEY` | Groq (recommended) | [console.groq.com/keys](https://console.groq.com/keys) | +| `GOOGLE_API_KEY` | Google Gemini | [aistudio.google.com](https://aistudio.google.com/app/apikey) | + +## πŸ—οΈ Architecture + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Clinical Insight Guild β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ 1. 
Biomarker Analyzer β”‚ β”‚ +β”‚ β”‚ Validates values, flags abnormalities β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β–Ό β–Ό β–Ό β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ Disease β”‚ β”‚Biomarker β”‚ β”‚ Clinical β”‚ β”‚ +β”‚ β”‚Explainer β”‚ β”‚ Linker β”‚ β”‚Guidelinesβ”‚ β”‚ +β”‚ β”‚ (RAG) β”‚ β”‚ β”‚ β”‚ (RAG) β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β–Ό β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ 4. Confidence Assessor β”‚ β”‚ +β”‚ β”‚ Evaluates reliability, assigns scores β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β–Ό β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ 5. 
Response Synthesizer β”‚ β”‚ +β”‚ β”‚ Compiles patient-friendly summary β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +## πŸ“Š Supported Biomarkers + +| Category | Biomarkers | +|----------|------------| +| **Diabetes** | Glucose, HbA1c, Fasting Glucose, Insulin | +| **Lipids** | Cholesterol, LDL, HDL, Triglycerides | +| **Kidney** | Creatinine, BUN, eGFR | +| **Liver** | ALT, AST, Bilirubin, Albumin | +| **Thyroid** | TSH, T3, T4, Free T4 | +| **Blood** | Hemoglobin, WBC, RBC, Platelets | +| **Cardiac** | Troponin, BNP, CRP | + +## ⚠️ Medical Disclaimer + +This tool is for **informational purposes only** and does not replace professional medical advice, diagnosis, or treatment. Always consult a qualified healthcare provider with questions regarding a medical condition. + +## πŸ“„ License + +MIT License β€” See [GitHub Repository](https://github.com/yourusername/ragbot) for details. + +## πŸ™ Acknowledgments + +Built with [LangGraph](https://langchain-ai.github.io/langgraph/), [FAISS](https://faiss.ai/), [Gradio](https://gradio.app/), and [Groq](https://groq.com/). 
diff --git a/huggingface/app.py b/huggingface/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4613f819e2fa9faef2fa5e023e37e00a3c7f7a06 --- /dev/null +++ b/huggingface/app.py @@ -0,0 +1,1025 @@ +""" +MediGuard AI β€” Hugging Face Spaces Gradio App + +Standalone deployment that uses: +- FAISS vector store (local) +- Cloud LLMs (Groq or Gemini - FREE tiers) +- No external services required +""" + +from __future__ import annotations + +import json +import logging +import os +import sys +import time +import traceback +from pathlib import Path +from typing import Any, Optional + +# Ensure project root is in path +_project_root = str(Path(__file__).parent.parent) +if _project_root not in sys.path: + sys.path.insert(0, _project_root) +os.chdir(_project_root) + +import gradio as gr + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(name)-20s | %(levelname)-7s | %(message)s", +) +logger = logging.getLogger("mediguard.huggingface") + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +def get_api_keys(): + """Get API keys dynamically (HuggingFace injects secrets after module load).""" + groq_key = os.getenv("GROQ_API_KEY", "") + google_key = os.getenv("GOOGLE_API_KEY", "") + return groq_key, google_key + + +def setup_llm_provider(): + """Set LLM provider based on available keys.""" + groq_key, google_key = get_api_keys() + + if groq_key: + os.environ["LLM_PROVIDER"] = "groq" + os.environ["GROQ_API_KEY"] = groq_key # Ensure it's set + return "groq" + elif google_key: + os.environ["LLM_PROVIDER"] = "gemini" + os.environ["GOOGLE_API_KEY"] = google_key + return "gemini" + return None + + +# Log status at startup (keys may not be available yet) +_groq, _google = get_api_keys() +if not _groq and not _google: + logger.warning( + "No LLM API key found at startup. Will check again when analyzing." 
+ ) + + +# --------------------------------------------------------------------------- +# Guild Initialization (lazy) +# --------------------------------------------------------------------------- + +_guild = None +_guild_error = None +_guild_provider = None # Track which provider was used + + +def reset_guild(): + """Reset guild to force re-initialization (e.g., when API key changes).""" + global _guild, _guild_error, _guild_provider + _guild = None + _guild_error = None + _guild_provider = None + + +def get_guild(): + """Lazy initialization of the Clinical Insight Guild.""" + global _guild, _guild_error, _guild_provider + + # Check if we need to reinitialize (provider changed) + current_provider = os.getenv("LLM_PROVIDER") + if _guild_provider and _guild_provider != current_provider: + logger.info(f"Provider changed from {_guild_provider} to {current_provider}, reinitializing...") + reset_guild() + + if _guild is not None: + return _guild + + if _guild_error is not None: + # Don't cache errors forever - allow retry + logger.warning("Previous initialization failed, retrying...") + _guild_error = None + + try: + logger.info("Initializing Clinical Insight Guild...") + logger.info(f"LLM_PROVIDER={os.getenv('LLM_PROVIDER')}") + logger.info(f"GROQ_API_KEY={'set' if os.getenv('GROQ_API_KEY') else 'NOT SET'}") + logger.info(f"GOOGLE_API_KEY={'set' if os.getenv('GOOGLE_API_KEY') else 'NOT SET'}") + + start = time.time() + + from src.workflow import create_guild + _guild = create_guild() + _guild_provider = current_provider + + elapsed = time.time() - start + logger.info(f"Guild initialized in {elapsed:.1f}s") + return _guild + + except Exception as exc: + logger.error(f"Failed to initialize guild: {exc}") + _guild_error = exc + raise + + +# --------------------------------------------------------------------------- +# Analysis Functions +# --------------------------------------------------------------------------- + +def parse_biomarkers(text: str) -> dict[str, float]: + 
""" + Parse biomarkers from natural language text. + + Supports formats like: + - "Glucose: 140, HbA1c: 7.5" + - "glucose 140 hba1c 7.5" + - {"Glucose": 140, "HbA1c": 7.5} + """ + text = text.strip() + + # Try JSON first + if text.startswith("{"): + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + # Parse natural language + import re + + # Common biomarker patterns + patterns = [ + # "Glucose: 140" or "Glucose = 140" + r"([A-Za-z0-9_]+)\s*[:=]\s*([\d.]+)", + # "Glucose 140 mg/dL" + r"([A-Za-z0-9_]+)\s+([\d.]+)\s*(?:mg/dL|mmol/L|%|g/dL|U/L|mIU/L)?", + ] + + biomarkers = {} + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + for name, value in matches: + try: + biomarkers[name.strip()] = float(value) + except ValueError: + continue + + return biomarkers + + +def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, str, str]: + """ + Analyze biomarkers using the Clinical Insight Guild. + + Returns: (summary, details_json, status) + """ + if not input_text.strip(): + return "", "", """ +
+ ✍️ +

Please enter biomarkers to analyze.

+
+ """ + + # Check API key dynamically (HF injects secrets after startup) + groq_key, google_key = get_api_keys() + + if not groq_key and not google_key: + return "", "", """ +
+ ❌ No API Key Configured +

Please add your API key in Space Settings β†’ Secrets:

+ +
+ """ + + # Setup provider based on available key + provider = setup_llm_provider() + logger.info(f"Using LLM provider: {provider}") + + try: + progress(0.1, desc="πŸ“ Parsing biomarkers...") + biomarkers = parse_biomarkers(input_text) + + if not biomarkers: + return "", "", """ +
+ ⚠️ Could not parse biomarkers +

Try formats like:

+ +
+ """ + + progress(0.2, desc="πŸ”§ Initializing AI agents...") + + # Initialize guild + guild = get_guild() + + # Prepare input + from src.state import PatientInput + + # Auto-generate prediction based on common patterns + prediction = auto_predict(biomarkers) + + patient_input = PatientInput( + biomarkers=biomarkers, + model_prediction=prediction, + patient_context={"patient_id": "HF_User", "source": "huggingface_spaces"} + ) + + progress(0.4, desc="πŸ€– Running Clinical Insight Guild...") + + # Run analysis + start = time.time() + result = guild.run(patient_input) + elapsed = time.time() - start + + progress(0.9, desc="✨ Formatting results...") + + # Extract response + final_response = result.get("final_response", {}) + + # Format summary + summary = format_summary(final_response, elapsed) + + # Format details + details = json.dumps(final_response, indent=2, default=str) + + status = f""" +
+ βœ… +
+ Analysis Complete + ({elapsed:.1f}s) +
+
+ """ + + return summary, details, status + + except Exception as exc: + logger.error(f"Analysis error: {exc}", exc_info=True) + error_msg = f""" +
+ ❌ Analysis Error +

{exc}

+
+ Show details +
{traceback.format_exc()}
+
+
+ """ + return "", "", error_msg + + +def auto_predict(biomarkers: dict[str, float]) -> dict[str, Any]: + """ + Auto-generate a disease prediction based on biomarkers. + This simulates what an ML model would provide. + """ + # Normalize biomarker names for matching + normalized = {k.lower().replace(" ", ""): v for k, v in biomarkers.items()} + + # Check for diabetes indicators + glucose = normalized.get("glucose", normalized.get("fastingglucose", 0)) + hba1c = normalized.get("hba1c", normalized.get("hemoglobina1c", 0)) + + if hba1c >= 6.5 or glucose >= 126: + return { + "disease": "Diabetes", + "confidence": min(0.95, 0.7 + (hba1c - 6.5) * 0.1) if hba1c else 0.85, + "severity": "high" if hba1c >= 8 or glucose >= 200 else "moderate" + } + + # Check for lipid disorders + cholesterol = normalized.get("cholesterol", normalized.get("totalcholesterol", 0)) + ldl = normalized.get("ldl", normalized.get("ldlcholesterol", 0)) + triglycerides = normalized.get("triglycerides", 0) + + if cholesterol >= 240 or ldl >= 160 or triglycerides >= 200: + return { + "disease": "Dyslipidemia", + "confidence": 0.85, + "severity": "moderate" + } + + # Check for anemia + hemoglobin = normalized.get("hemoglobin", normalized.get("hgb", normalized.get("hb", 0))) + + if hemoglobin and hemoglobin < 12: + return { + "disease": "Anemia", + "confidence": 0.80, + "severity": "moderate" + } + + # Check for thyroid issues + tsh = normalized.get("tsh", 0) + + if tsh > 4.5: + return { + "disease": "Hypothyroidism", + "confidence": 0.75, + "severity": "moderate" + } + elif tsh and tsh < 0.4: + return { + "disease": "Hyperthyroidism", + "confidence": 0.75, + "severity": "moderate" + } + + # Default - general health screening + return { + "disease": "General Health Screening", + "confidence": 0.70, + "severity": "low" + } + + +def format_summary(response: dict, elapsed: float) -> str: + """Format the analysis response as beautiful HTML/markdown.""" + if not response: + return """ +
+
❌
+

No analysis results available.

+
+ """ + + parts = [] + + # Header with primary finding and confidence + primary = response.get("primary_finding", "Analysis Complete") + confidence = response.get("confidence", {}) + conf_score = confidence.get("overall_score", 0) if isinstance(confidence, dict) else 0 + + # Determine severity color + severity = response.get("severity", "low") + severity_colors = { + "critical": ("#dc2626", "#fef2f2", "πŸ”΄"), + "high": ("#ea580c", "#fff7ed", "🟠"), + "moderate": ("#ca8a04", "#fefce8", "🟑"), + "low": ("#16a34a", "#f0fdf4", "🟒") + } + color, bg_color, emoji = severity_colors.get(severity, severity_colors["low"]) + + # Confidence badge + conf_badge = "" + if conf_score: + conf_pct = int(conf_score * 100) + conf_color = "#16a34a" if conf_pct >= 80 else "#ca8a04" if conf_pct >= 60 else "#dc2626" + conf_badge = f'{conf_pct}% confidence' + + parts.append(f""" +
+
+ {emoji} +

{primary}

+ {conf_badge} +
+
+ """) + + # Critical Alerts + alerts = response.get("safety_alerts", []) + if alerts: + alert_items = "" + for alert in alerts[:5]: + if isinstance(alert, dict): + alert_items += f'
  • {alert.get("alert_type", "Alert")}: {alert.get("message", "")}
  • ' + else: + alert_items += f'
  • {alert}
  • ' + + parts.append(f""" +
    +

    + ⚠️ Critical Alerts +

    + +
    + """) + + # Key Findings + findings = response.get("key_findings", []) + if findings: + finding_items = "".join([f'
  • {f}
  • ' for f in findings[:5]]) + parts.append(f""" +
    +

    πŸ” Key Findings

    + +
    + """) + + # Biomarker Flags - as a visual grid + flags = response.get("biomarker_flags", []) + if flags: + flag_cards = "" + for flag in flags[:8]: + if isinstance(flag, dict): + name = flag.get("biomarker", "Unknown") + status = flag.get("status", "normal") + value = flag.get("value", "N/A") + + status_styles = { + "critical": ("πŸ”΄", "#dc2626", "#fef2f2"), + "abnormal": ("🟑", "#ca8a04", "#fefce8"), + "normal": ("🟒", "#16a34a", "#f0fdf4") + } + s_emoji, s_color, s_bg = status_styles.get(status, status_styles["normal"]) + + flag_cards += f""" +
    +
    {s_emoji}
    +
    {name}
    +
    {value}
    +
    {status}
    +
    + """ + + parts.append(f""" +
    +

    πŸ“Š Biomarker Analysis

    +
    + {flag_cards} +
    +
    + """) + + # Recommendations - organized sections + recs = response.get("recommendations", {}) + if recs: + rec_sections = "" + + immediate = recs.get("immediate_actions", []) + if immediate: + items = "".join([f'
  • {a}
  • ' for a in immediate[:3]]) + rec_sections += f""" +
    +
    🚨 Immediate Actions
    + +
    + """ + + lifestyle = recs.get("lifestyle_modifications", []) + if lifestyle: + items = "".join([f'
  • {m}
  • ' for m in lifestyle[:3]]) + rec_sections += f""" +
    +
    🌿 Lifestyle Modifications
    + +
    + """ + + followup = recs.get("follow_up", []) + if followup: + items = "".join([f'
  • {f}
  • ' for f in followup[:3]]) + rec_sections += f""" +
    +
    πŸ“… Follow-up
    + +
    + """ + + if rec_sections: + parts.append(f""" +
    +

    πŸ’‘ Recommendations

    + {rec_sections} +
    + """) + + # Disease Explanation + explanation = response.get("disease_explanation", {}) + if explanation and isinstance(explanation, dict): + pathophys = explanation.get("pathophysiology", "") + if pathophys: + parts.append(f""" +
    +

    πŸ“– Understanding Your Results

    +

    {pathophys[:600]}{'...' if len(pathophys) > 600 else ''}

    +
    + """) + + # Conversational Summary + conv_summary = response.get("conversational_summary", "") + if conv_summary: + parts.append(f""" +
    +

    πŸ“ Summary

    +

    {conv_summary[:1000]}

    +
    + """) + + # Footer + parts.append(f""" +
    +

    + ✨ Analysis completed in {elapsed:.1f}s using Agentic RagBot +

    +

    + ⚠️ This is for informational purposes only. Consult a healthcare professional for medical advice. +

    +
    + """) + + return "\n".join(parts) + + +# --------------------------------------------------------------------------- +# Gradio Interface +# --------------------------------------------------------------------------- + +# Custom CSS for modern medical UI +CUSTOM_CSS = """ +/* Global Styles */ +.gradio-container { + max-width: 1400px !important; + margin: auto !important; + font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important; +} + +/* Hide footer */ +footer { display: none !important; } + +/* Header styling */ +.header-container { + background: linear-gradient(135deg, #1e3a5f 0%, #2d5a87 50%, #3d7ab5 100%); + border-radius: 16px; + padding: 32px; + margin-bottom: 24px; + color: white; + text-align: center; + box-shadow: 0 8px 32px rgba(30, 58, 95, 0.3); +} + +.header-container h1 { + margin: 0 0 12px 0; + font-size: 2.5em; + font-weight: 700; + text-shadow: 0 2px 4px rgba(0,0,0,0.2); +} + +.header-container p { + margin: 0; + opacity: 0.95; + font-size: 1.1em; +} + +/* Input panel */ +.input-panel { + background: linear-gradient(180deg, #f8fafc 0%, #f1f5f9 100%); + border-radius: 16px; + padding: 24px; + border: 1px solid #e2e8f0; + box-shadow: 0 4px 16px rgba(0, 0, 0, 0.05); +} + +/* Output panel */ +.output-panel { + background: white; + border-radius: 16px; + padding: 24px; + border: 1px solid #e2e8f0; + box-shadow: 0 4px 16px rgba(0, 0, 0, 0.05); + min-height: 500px; +} + +/* Status badges */ +.status-success { + background: linear-gradient(135deg, #10b981 0%, #059669 100%); + color: white; + padding: 12px 20px; + border-radius: 10px; + font-weight: 600; + display: inline-block; +} + +.status-error { + background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%); + color: white; + padding: 12px 20px; + border-radius: 10px; + font-weight: 600; +} + +.status-warning { + background: linear-gradient(135deg, #f59e0b 0%, #d97706 100%); + color: white; + padding: 12px 20px; + border-radius: 10px; + font-weight: 600; +} + +/* Info banner 
*/ +.info-banner { + background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%); + border: 1px solid #93c5fd; + border-radius: 12px; + padding: 16px 20px; + margin: 16px 0; + display: flex; + align-items: center; + gap: 12px; +} + +.info-banner-icon { + font-size: 1.5em; +} + +/* Agent cards */ +.agent-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); + gap: 16px; + margin: 20px 0; +} + +.agent-card { + background: linear-gradient(135deg, #ffffff 0%, #f8fafc 100%); + border: 1px solid #e2e8f0; + border-radius: 12px; + padding: 20px; + transition: all 0.3s ease; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.04); +} + +.agent-card:hover { + transform: translateY(-2px); + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.1); + border-color: #3b82f6; +} + +.agent-card h4 { + margin: 0 0 8px 0; + color: #1e3a5f; + font-size: 1em; +} + +.agent-card p { + margin: 0; + color: #64748b; + font-size: 0.9em; +} + +/* Example buttons */ +.example-btn { + background: #f1f5f9; + border: 1px solid #cbd5e1; + border-radius: 8px; + padding: 10px 14px; + cursor: pointer; + transition: all 0.2s ease; + text-align: left; + font-size: 0.85em; +} + +.example-btn:hover { + background: #e2e8f0; + border-color: #94a3b8; +} + +/* Buttons */ +.primary-btn { + background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%) !important; + border: none !important; + border-radius: 12px !important; + padding: 14px 28px !important; + font-weight: 600 !important; + font-size: 1.1em !important; + box-shadow: 0 4px 14px rgba(59, 130, 246, 0.4) !important; + transition: all 0.3s ease !important; +} + +.primary-btn:hover { + transform: translateY(-2px) !important; + box-shadow: 0 6px 20px rgba(59, 130, 246, 0.5) !important; +} + +.secondary-btn { + background: #f1f5f9 !important; + border: 1px solid #cbd5e1 !important; + border-radius: 12px !important; + padding: 14px 28px !important; + font-weight: 500 !important; + transition: all 0.2s ease !important; +} + +.secondary-btn:hover { 
+ background: #e2e8f0 !important; +} + +/* Results tabs */ +.results-tabs { + border-radius: 12px; + overflow: hidden; +} + +/* Disclaimer */ +.disclaimer { + background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); + border: 1px solid #fbbf24; + border-radius: 12px; + padding: 16px 20px; + margin-top: 24px; + font-size: 0.9em; +} + +/* Feature badges */ +.feature-badge { + display: inline-block; + background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); + color: #4338ca; + padding: 6px 12px; + border-radius: 20px; + font-size: 0.8em; + font-weight: 600; + margin: 4px; +} + +/* Section titles */ +.section-title { + font-size: 1.25em; + font-weight: 600; + color: #1e3a5f; + margin-bottom: 16px; + display: flex; + align-items: center; + gap: 8px; +} + +/* Animations */ +@keyframes pulse { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.7; } +} + +.analyzing { + animation: pulse 1.5s ease-in-out infinite; +} +""" + + +def create_demo() -> gr.Blocks: + """Create the Gradio Blocks interface with modern medical UI.""" + + with gr.Blocks( + title="Agentic RagBot - Medical Biomarker Analysis", + theme=gr.themes.Soft( + primary_hue=gr.themes.colors.blue, + secondary_hue=gr.themes.colors.slate, + neutral_hue=gr.themes.colors.slate, + font=gr.themes.GoogleFont("Inter"), + font_mono=gr.themes.GoogleFont("JetBrains Mono"), + ).set( + body_background_fill="linear-gradient(135deg, #f0f4f8 0%, #e2e8f0 100%)", + block_background_fill="white", + block_border_width="0px", + block_shadow="0 4px 16px rgba(0, 0, 0, 0.08)", + block_radius="16px", + button_primary_background_fill="linear-gradient(135deg, #3b82f6 0%, #2563eb 100%)", + button_primary_background_fill_hover="linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%)", + button_primary_text_color="white", + button_primary_shadow="0 4px 14px rgba(59, 130, 246, 0.4)", + input_background_fill="#f8fafc", + input_border_width="1px", + input_border_color="#e2e8f0", + input_radius="12px", + ), + css=CUSTOM_CSS, + ) as demo: + + # 
===== HEADER ===== + gr.HTML(""" +
    +

    πŸ₯ Agentic RagBot

    +

    Multi-Agent RAG System for Medical Biomarker Analysis

    +
    + πŸ€– 6 AI Agents + πŸ“š RAG-Powered + ⚑ Real-time Analysis + πŸ”¬ Evidence-Based +
    +
    + """) + + # ===== API KEY INFO ===== + gr.HTML(""" +
    + πŸ”‘ +
    + Setup Required: Add your GROQ_API_KEY or + GOOGLE_API_KEY in Space Settings β†’ Secrets to enable analysis. + Get free Groq key β†’ +
    +
    + """) + + # ===== MAIN CONTENT ===== + with gr.Row(equal_height=False): + + # ----- LEFT PANEL: INPUT ----- + with gr.Column(scale=2, min_width=400): + gr.HTML('
    πŸ“ Enter Your Biomarkers
    ') + + with gr.Group(): + input_text = gr.Textbox( + label="", + placeholder="Enter biomarkers in any format:\n\nβ€’ Glucose: 140, HbA1c: 7.5, Cholesterol: 210\nβ€’ My glucose is 140 and HbA1c is 7.5\nβ€’ {\"Glucose\": 140, \"HbA1c\": 7.5}", + lines=6, + max_lines=12, + show_label=False, + ) + + with gr.Row(): + analyze_btn = gr.Button( + "πŸ”¬ Analyze Biomarkers", + variant="primary", + size="lg", + scale=3, + ) + clear_btn = gr.Button( + "πŸ—‘οΈ Clear", + variant="secondary", + size="lg", + scale=1, + ) + + # Status display + status_output = gr.Markdown( + value="", + elem_classes="status-box" + ) + + # Quick Examples + gr.HTML('
    ⚑ Quick Examples
    ') + gr.HTML('

    Click any example to load it instantly

    ') + + examples = gr.Examples( + examples=[ + ["Glucose: 185, HbA1c: 8.2, Cholesterol: 245, LDL: 165"], + ["Glucose: 95, HbA1c: 5.4, Cholesterol: 180, HDL: 55, LDL: 100"], + ["Hemoglobin: 9.5, Iron: 40, Ferritin: 15"], + ["TSH: 8.5, T4: 4.0, T3: 80"], + ["Creatinine: 2.5, BUN: 45, eGFR: 35"], + ], + inputs=input_text, + label="", + ) + + # Supported Biomarkers + with gr.Accordion("πŸ“Š Supported Biomarkers", open=False): + gr.HTML(""" +
    +
    +

    🩸 Diabetes

    +

    Glucose, HbA1c, Fasting Glucose, Insulin

    +
    +
    +

    ❀️ Cardiovascular

    +

    Cholesterol, LDL, HDL, Triglycerides

    +
    +
    +

    🫘 Kidney

    +

    Creatinine, BUN, eGFR, Uric Acid

    +
    +
    +

    🦴 Liver

    +

    ALT, AST, Bilirubin, Albumin

    +
    +
    +

    πŸ¦‹ Thyroid

    +

    TSH, T3, T4, Free T4

    +
    +
    +

    πŸ’‰ Blood

    +

    Hemoglobin, WBC, RBC, Platelets

    +
    +
    + """) + + # ----- RIGHT PANEL: RESULTS ----- + with gr.Column(scale=3, min_width=500): + gr.HTML('
    πŸ“Š Analysis Results
    ') + + with gr.Tabs() as result_tabs: + with gr.Tab("πŸ“‹ Summary", id="summary"): + summary_output = gr.Markdown( + value=""" +
    +
    🔬
    +

    Ready to Analyze

    +

    Enter your biomarkers on the left and click Analyze to get your personalized health insights.

    +
    + """, + elem_classes="summary-output" + ) + + with gr.Tab("πŸ” Detailed JSON", id="json"): + details_output = gr.Code( + label="", + language="json", + lines=30, + show_label=False, + ) + + # ===== HOW IT WORKS ===== + gr.HTML('
    🤖 How It Works
    ') + + gr.HTML(""" +
    +
    +

    🔬 Biomarker Analyzer

    +

    Validates your biomarker values against clinical reference ranges and flags any abnormalities.

    +
    +
    +

    📚 Disease Explainer

    +

    Uses RAG to retrieve relevant medical literature and explain potential conditions.

    +
    +
    +

    🔗 Biomarker Linker

    +

    Connects your specific biomarker patterns to disease predictions with clinical evidence.

    +
    +
    +

    📋 Clinical Guidelines

    +

    Retrieves evidence-based recommendations from 750+ pages of medical guidelines.

    +
    +
    +

    ✅ Confidence Assessor

    +

    Evaluates the reliability of findings based on data quality and evidence strength.

    +
    +
    +

    📝 Response Synthesizer

    +

    Compiles all insights into a comprehensive, easy-to-understand patient report.

    +
    +
    + """) + + # ===== DISCLAIMER ===== + gr.HTML(""" +
    + ⚠️ Medical Disclaimer: This tool is for informational purposes only + and does not replace professional medical advice, diagnosis, or treatment. Always consult a qualified + healthcare provider with questions regarding a medical condition. The AI analysis is based on general + clinical guidelines and may not account for your specific medical history. +
    + """) + + # ===== FOOTER ===== + gr.HTML(""" +
    +

    Built with ❤️ using + LangGraph, + FAISS, and + Gradio +

    +

    Powered by Groq (LLaMA 3.3-70B) • Open Source on GitHub

    +
    + """) + + # ===== EVENT HANDLERS ===== + analyze_btn.click( + fn=analyze_biomarkers, + inputs=[input_text], + outputs=[summary_output, details_output, status_output], + show_progress="full", + ) + + clear_btn.click( + fn=lambda: ("", """ +
    +
    🔬
    +

    Ready to Analyze

    +

    Enter your biomarkers on the left and click Analyze to get your personalized health insights.

    +
    + """, "", ""), + outputs=[input_text, summary_output, details_output, status_output], + ) + + return demo + + +# --------------------------------------------------------------------------- +# Main Entry Point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + logger.info("Starting MediGuard AI Gradio App...") + + demo = create_demo() + + # Launch with HF Spaces compatible settings + demo.launch( + server_name="0.0.0.0", + server_port=7860, + show_error=True, + # share=False on HF Spaces + ) diff --git a/huggingface/requirements.txt b/huggingface/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e7008b543ee0dbe144b93711828552c03ab7c79d --- /dev/null +++ b/huggingface/requirements.txt @@ -0,0 +1,42 @@ +# =========================================================================== +# MediGuard AI β€” Hugging Face Spaces Dependencies +# =========================================================================== +# Minimal dependencies for standalone Gradio deployment. +# No postgres, redis, opensearch, ollama required. 
+# =========================================================================== + +# --- Gradio UI --- +gradio>=5.0.0 + +# --- LangChain Core --- +langchain>=0.3.0 +langchain-community>=0.3.0 +langchain-core>=0.3.0 +langchain-text-splitters>=0.3.0 +langgraph>=0.2.0 + +# --- Cloud LLM Providers (FREE tiers) --- +langchain-groq>=0.2.0 +langchain-google-genai>=2.0.0 + +# --- Vector Store --- +faiss-cpu>=1.8.0 + +# --- Embeddings (local - no API key needed) --- +sentence-transformers>=3.0.0 +langchain-huggingface>=0.1.0 + +# --- Document Processing --- +pypdf>=4.0.0 + +# --- Pydantic --- +pydantic>=2.9.0 +pydantic-settings>=2.5.0 + +# --- HTTP Client --- +httpx>=0.27.0 + +# --- Utilities --- +python-dotenv>=1.0.0 +tenacity>=8.0.0 +numpy<2.0.0 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..4388e198e05679b3e6c440b7e571c7af1697bdd0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,117 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "mediguard-ai" +version = "2.0.0" +description = "Production medical biomarker analysis β€” agentic RAG + multi-agent workflow" +readme = "README.md" +license = { text = "MIT" } +requires-python = ">=3.11" +authors = [{ name = "MediGuard AI Team" }] + +dependencies = [ + # --- Core --- + "fastapi>=0.115.0", + "uvicorn[standard]>=0.30.0", + "pydantic>=2.9.0", + "pydantic-settings>=2.5.0", + # --- LLM / LangChain --- + "langchain>=0.3.0", + "langchain-community>=0.3.0", + "langgraph>=0.2.0", + # --- Vector / Search --- + "opensearch-py>=2.7.0", + "faiss-cpu>=1.8.0", + # --- Embeddings --- + "httpx>=0.27.0", + # --- Database --- + "sqlalchemy>=2.0.0", + "psycopg2-binary>=2.9.0", + "alembic>=1.13.0", + # --- Cache --- + "redis>=5.0.0", + # --- PDF --- + "pypdf>=4.0.0", + # --- Observability --- + "langfuse>=2.0.0", + # --- Utilities --- + "python-dotenv>=1.0.0", + "tenacity>=8.0.0", +] + +[project.optional-dependencies] +docling = 
["docling>=2.0.0"] +telegram = ["python-telegram-bot>=21.0", "httpx>=0.27.0"] +gradio = ["gradio>=5.0.0", "httpx>=0.27.0"] +airflow = ["apache-airflow>=2.9.0"] +google = ["langchain-google-genai>=2.0.0"] +groq = ["langchain-groq>=0.2.0"] +huggingface = ["sentence-transformers>=3.0.0"] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=5.0.0", + "ruff>=0.7.0", + "mypy>=1.12.0", + "pre-commit>=3.8.0", + "httpx>=0.27.0", +] +all = [ + "mediguard-ai[docling,telegram,gradio,google,groq,huggingface,dev]", +] + +[project.scripts] +mediguard = "src.main:app" +mediguard-telegram = "src.services.telegram.bot:MediGuardTelegramBot" +mediguard-gradio = "src.gradio_app:launch_gradio" + +# -------------------------------------------------------------------------- +# Ruff +# -------------------------------------------------------------------------- +[tool.ruff] +target-version = "py311" +line-length = 120 +fix = true + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "RUF", # ruff-specific +] +ignore = [ + "E501", # line too long β€” handled by formatter + "B008", # do not perform function calls in argument defaults (Depends) + "SIM108", # ternary operator +] + +[tool.ruff.lint.isort] +known-first-party = ["src"] + +# -------------------------------------------------------------------------- +# MyPy +# -------------------------------------------------------------------------- +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false # gradually enable +ignore_missing_imports = true + +# -------------------------------------------------------------------------- +# Pytest +# -------------------------------------------------------------------------- +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] 
+python_functions = ["test_*"] +addopts = "-v --tb=short -q" +filterwarnings = ["ignore::DeprecationWarning"] diff --git a/scripts/deploy_huggingface.ps1 b/scripts/deploy_huggingface.ps1 new file mode 100644 index 0000000000000000000000000000000000000000..2afc7f203511547a180656492698244b81283c64 --- /dev/null +++ b/scripts/deploy_huggingface.ps1 @@ -0,0 +1,139 @@ +<# +.SYNOPSIS + Deploy MediGuard AI to Hugging Face Spaces +.DESCRIPTION + This script automates the deployment of MediGuard AI to Hugging Face Spaces. + It handles copying files, setting up the Dockerfile, and pushing to the Space. +.PARAMETER SpaceName + Name of your Hugging Face Space (e.g., "mediguard-ai") +.PARAMETER Username + Your Hugging Face username +.PARAMETER SkipClone + Skip cloning if you've already cloned the Space +.EXAMPLE + .\deploy_huggingface.ps1 -Username "your-username" -SpaceName "mediguard-ai" +#> + +param( + [Parameter(Mandatory=$true)] + [string]$Username, + + [Parameter(Mandatory=$false)] + [string]$SpaceName = "mediguard-ai", + + [switch]$SkipClone +) + +$ErrorActionPreference = "Stop" + +Write-Host "========================================" -ForegroundColor Cyan +Write-Host " MediGuard AI - Hugging Face Deployment" -ForegroundColor Cyan +Write-Host "========================================" -ForegroundColor Cyan +Write-Host "" + +# Configuration +$ProjectRoot = Split-Path -Parent $PSScriptRoot +$DeployDir = Join-Path $ProjectRoot "hf-deploy" +$SpaceUrl = "https://huggingface.co/spaces/$Username/$SpaceName" + +Write-Host "Project Root: $ProjectRoot" -ForegroundColor Gray +Write-Host "Deploy Dir: $DeployDir" -ForegroundColor Gray +Write-Host "Space URL: $SpaceUrl" -ForegroundColor Gray +Write-Host "" + +# Step 1: Clone or use existing Space +if (-not $SkipClone) { + Write-Host "[1/6] Cloning Hugging Face Space..." -ForegroundColor Yellow + + if (Test-Path $DeployDir) { + Write-Host " Removing existing deploy directory..." 
-ForegroundColor Gray + Remove-Item -Recurse -Force $DeployDir + } + + git clone "https://huggingface.co/spaces/$Username/$SpaceName" $DeployDir + + if ($LASTEXITCODE -ne 0) { + Write-Host "ERROR: Failed to clone Space. Make sure it exists!" -ForegroundColor Red + Write-Host "Create it at: https://huggingface.co/new-space" -ForegroundColor Yellow + exit 1 + } +} else { + Write-Host "[1/6] Using existing deploy directory..." -ForegroundColor Yellow +} + +# Step 2: Copy project files +Write-Host "[2/6] Copying project files..." -ForegroundColor Yellow + +# Core directories +$CoreDirs = @("src", "config", "data", "huggingface") +foreach ($dir in $CoreDirs) { + $source = Join-Path $ProjectRoot $dir + $dest = Join-Path $DeployDir $dir + if (Test-Path $source) { + Write-Host " Copying $dir..." -ForegroundColor Gray + Copy-Item -Path $source -Destination $dest -Recurse -Force + } +} + +# Copy specific files +$CoreFiles = @("pyproject.toml", ".dockerignore") +foreach ($file in $CoreFiles) { + $source = Join-Path $ProjectRoot $file + if (Test-Path $source) { + Write-Host " Copying $file..." -ForegroundColor Gray + Copy-Item -Path $source -Destination (Join-Path $DeployDir $file) -Force + } +} + +# Step 3: Set up Dockerfile (HF Spaces expects it in root) +Write-Host "[3/6] Setting up Dockerfile..." -ForegroundColor Yellow +$HfDockerfile = Join-Path $DeployDir "huggingface/Dockerfile" +$RootDockerfile = Join-Path $DeployDir "Dockerfile" +Copy-Item -Path $HfDockerfile -Destination $RootDockerfile -Force +Write-Host " Copied huggingface/Dockerfile to Dockerfile" -ForegroundColor Gray + +# Step 4: Set up README with HF metadata +Write-Host "[4/6] Setting up README.md..." 
-ForegroundColor Yellow +$HfReadme = Join-Path $DeployDir "huggingface/README.md" +$RootReadme = Join-Path $DeployDir "README.md" +Copy-Item -Path $HfReadme -Destination $RootReadme -Force +Write-Host " Copied huggingface/README.md to README.md" -ForegroundColor Gray + +# Step 5: Verify vector store exists +Write-Host "[5/6] Verifying vector store..." -ForegroundColor Yellow +$VectorStore = Join-Path $DeployDir "data/vector_stores/medical_knowledge.faiss" +if (Test-Path $VectorStore) { + $size = (Get-Item $VectorStore).Length / 1MB + Write-Host " Vector store found: $([math]::Round($size, 2)) MB" -ForegroundColor Green +} else { + Write-Host " WARNING: Vector store not found!" -ForegroundColor Red + Write-Host " Run 'python scripts/setup_embeddings.py' first to create it." -ForegroundColor Yellow +} + +# Step 6: Commit and push +Write-Host "[6/6] Committing and pushing to Hugging Face..." -ForegroundColor Yellow + +Push-Location $DeployDir + +git add . +git commit -m "Deploy MediGuard AI - $(Get-Date -Format 'yyyy-MM-dd HH:mm')" + +Write-Host "" +Write-Host "Ready to push! Run the following command:" -ForegroundColor Green +Write-Host "" +Write-Host " cd $DeployDir" -ForegroundColor Cyan +Write-Host " git push" -ForegroundColor Cyan +Write-Host "" +Write-Host "After pushing, add your API key as a Secret in Space Settings:" -ForegroundColor Yellow +Write-Host " Name: GROQ_API_KEY (or GOOGLE_API_KEY)" -ForegroundColor Gray +Write-Host " Value: your-api-key" -ForegroundColor Gray +Write-Host "" +Write-Host "Your Space will be live at:" -ForegroundColor Green +Write-Host " $SpaceUrl" -ForegroundColor Cyan + +Pop-Location + +Write-Host "" +Write-Host "========================================" -ForegroundColor Cyan +Write-Host " Deployment prepared successfully!" 
-ForegroundColor Green +Write-Host "========================================" -ForegroundColor Cyan diff --git a/src/database.py b/src/database.py new file mode 100644 index 0000000000000000000000000000000000000000..6111e83049b25728ab827313378e402824733591 --- /dev/null +++ b/src/database.py @@ -0,0 +1,50 @@ +""" +MediGuard AI β€” Database layer + +Provides SQLAlchemy engine/session factories and the declarative Base. +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Generator + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker, DeclarativeBase + +from src.settings import get_settings + + +class Base(DeclarativeBase): + """Shared declarative base for all ORM models.""" + pass + + +@lru_cache(maxsize=1) +def _engine(): + settings = get_settings() + return create_engine( + settings.postgres.database_url, + pool_pre_ping=True, + pool_size=5, + max_overflow=10, + echo=settings.debug, + ) + + +@lru_cache(maxsize=1) +def _session_factory() -> sessionmaker[Session]: + return sessionmaker(bind=_engine(), autocommit=False, autoflush=False) + + +def get_db() -> Generator[Session, None, None]: + """FastAPI dependency β€” yields a DB session and commits/rolls back.""" + session = _session_factory()() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/dependencies.py b/src/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..77f35d9756a7f03569e33eea1bee6164c48b4cc4 --- /dev/null +++ b/src/dependencies.py @@ -0,0 +1,36 @@ +""" +MediGuard AI β€” FastAPI Dependency Injection + +Provides factory functions and ``Depends()`` for services used across routers. 
+""" + +from __future__ import annotations + +from functools import lru_cache + +from src.settings import Settings, get_settings +from src.services.cache.redis_cache import RedisCache, make_redis_cache +from src.services.embeddings.service import EmbeddingService, make_embedding_service +from src.services.langfuse.tracer import LangfuseTracer, make_langfuse_tracer +from src.services.ollama.client import OllamaClient, make_ollama_client +from src.services.opensearch.client import OpenSearchClient, make_opensearch_client + + +def get_opensearch_client() -> OpenSearchClient: + return make_opensearch_client() + + +def get_embedding_service() -> EmbeddingService: + return make_embedding_service() + + +def get_redis_cache() -> RedisCache: + return make_redis_cache() + + +def get_ollama_client() -> OllamaClient: + return make_ollama_client() + + +def get_langfuse_tracer() -> LangfuseTracer: + return make_langfuse_tracer() diff --git a/src/exceptions.py b/src/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..ff58d21c4d2a647763985de61e6e43fc60079b6a --- /dev/null +++ b/src/exceptions.py @@ -0,0 +1,149 @@ +""" +MediGuard AI β€” Domain Exception Hierarchy + +Production-grade exception classes for the medical RAG system. +Each service layer raises its own exception type so callers can handle +failures precisely without leaking implementation details. 
+""" + +from typing import Any, Dict, Optional + + +# ── Base ────────────────────────────────────────────────────────────────────── + +class MediGuardError(Exception): + """Root exception for the entire MediGuard AI application.""" + + def __init__(self, message: str = "", *, details: Optional[Dict[str, Any]] = None): + self.details = details or {} + super().__init__(message) + + +# ── Configuration / startup ────────────────────────────────────────────────── + +class ConfigurationError(MediGuardError): + """Raised when a required setting is missing or invalid.""" + + +class ServiceInitError(MediGuardError): + """Raised when a service fails to initialise during app startup.""" + + +# ── Database ───────────────────────────────────────────────────────────────── + +class DatabaseError(MediGuardError): + """Base class for all database-related errors.""" + + +class ConnectionError(DatabaseError): + """Could not connect to PostgreSQL.""" + + +class RecordNotFoundError(DatabaseError): + """Expected record does not exist.""" + + +# ── Search engine ──────────────────────────────────────────────────────────── + +class SearchError(MediGuardError): + """Base class for search-engine (OpenSearch) errors.""" + + +class IndexNotFoundError(SearchError): + """The requested OpenSearch index does not exist.""" + + +class SearchQueryError(SearchError): + """The search query was malformed or returned an error.""" + + +# ── Embeddings ─────────────────────────────────────────────────────────────── + +class EmbeddingError(MediGuardError): + """Failed to generate embeddings.""" + + +class EmbeddingProviderError(EmbeddingError): + """The upstream embedding provider returned an error.""" + + +# ── PDF / document parsing ─────────────────────────────────────────────────── + +class PDFParsingError(MediGuardError): + """Base class for PDF-processing errors.""" + + +class PDFExtractionError(PDFParsingError): + """Could not extract text from a PDF document.""" + + +class 
PDFValidationError(PDFParsingError): + """Uploaded PDF failed validation (size, format, etc.).""" + + +# ── LLM / Ollama ───────────────────────────────────────────────────────────── + +class LLMError(MediGuardError): + """Base class for LLM-related errors.""" + + +class OllamaConnectionError(LLMError): + """Could not reach the Ollama server.""" + + +class OllamaModelNotFoundError(LLMError): + """The requested Ollama model is not pulled/available.""" + + +class LLMResponseError(LLMError): + """The LLM returned an unparseable or empty response.""" + + +# ── Biomarker domain ───────────────────────────────────────────────────────── + +class BiomarkerError(MediGuardError): + """Base class for biomarker-related errors.""" + + +class BiomarkerValidationError(BiomarkerError): + """A biomarker value is physiologically implausible.""" + + +class BiomarkerNotFoundError(BiomarkerError): + """The biomarker name is unknown to the system.""" + + +# ── Medical analysis / workflow ────────────────────────────────────────────── + +class AnalysisError(MediGuardError): + """The clinical-analysis workflow encountered an error.""" + + +class GuardrailError(MediGuardError): + """A safety guardrail was triggered (input or output).""" + + +class OutOfScopeError(GuardrailError): + """The user query falls outside the medical domain.""" + + +# ── Cache ──────────────────────────────────────────────────────────────────── + +class CacheError(MediGuardError): + """Base class for cache (Redis) errors.""" + + +class CacheConnectionError(CacheError): + """Could not connect to Redis.""" + + +# ── Observability ──────────────────────────────────────────────────────────── + +class ObservabilityError(MediGuardError): + """Langfuse or metrics reporting failed (non-fatal).""" + + +# ── Telegram bot ───────────────────────────────────────────────────────────── + +class TelegramError(MediGuardError): + """Error from the Telegram bot integration.""" diff --git a/src/gradio_app.py b/src/gradio_app.py new 
file mode 100644 index 0000000000000000000000000000000000000000..fee2e2d31434d4e7753e389e613764f7dae83366 --- /dev/null +++ b/src/gradio_app.py @@ -0,0 +1,121 @@ +""" +MediGuard AI β€” Gradio Web UI + +Provides a simple chat interface and biomarker analysis panel. +""" + +from __future__ import annotations + +import json +import logging +import os + +import httpx + +logger = logging.getLogger(__name__) + +API_BASE = os.getenv("MEDIGUARD_API_URL", "http://localhost:8000") + + +def _call_ask(question: str) -> str: + """Call the /ask endpoint.""" + try: + with httpx.Client(timeout=60.0) as client: + resp = client.post(f"{API_BASE}/ask", json={"question": question}) + resp.raise_for_status() + return resp.json().get("answer", "No answer returned.") + except Exception as exc: + return f"Error: {exc}" + + +def _call_analyze(biomarkers_json: str) -> str: + """Call the /analyze/structured endpoint.""" + try: + biomarkers = json.loads(biomarkers_json) + with httpx.Client(timeout=60.0) as client: + resp = client.post( + f"{API_BASE}/analyze/structured", + json={"biomarkers": biomarkers}, + ) + resp.raise_for_status() + data = resp.json() + summary = data.get("conversational_summary") or json.dumps(data, indent=2) + return summary + except json.JSONDecodeError: + return "Invalid JSON. Please enter biomarkers as: {\"Glucose\": 185, \"HbA1c\": 8.2}" + except Exception as exc: + return f"Error: {exc}" + + +def launch_gradio(share: bool = False) -> None: + """Launch the Gradio interface.""" + try: + import gradio as gr + except ImportError: + raise ImportError("gradio is required. Install: pip install gradio") + + with gr.Blocks(title="MediGuard AI", theme=gr.themes.Soft()) as demo: + gr.Markdown("# πŸ₯ MediGuard AI β€” Medical Analysis") + gr.Markdown( + "**Disclaimer**: This tool is for informational purposes only and does not " + "replace professional medical advice." 
+ ) + + with gr.Tab("Ask a Question"): + question_input = gr.Textbox( + label="Medical Question", + placeholder="e.g., What does a high HbA1c level indicate?", + lines=3, + ) + ask_btn = gr.Button("Ask", variant="primary") + answer_output = gr.Textbox(label="Answer", lines=15, interactive=False) + ask_btn.click(fn=_call_ask, inputs=question_input, outputs=answer_output) + + with gr.Tab("Analyze Biomarkers"): + bio_input = gr.Textbox( + label="Biomarkers (JSON)", + placeholder='{"Glucose": 185, "HbA1c": 8.2, "Cholesterol": 210}', + lines=5, + ) + analyze_btn = gr.Button("Analyze", variant="primary") + analysis_output = gr.Textbox(label="Analysis", lines=20, interactive=False) + analyze_btn.click(fn=_call_analyze, inputs=bio_input, outputs=analysis_output) + + with gr.Tab("Search Knowledge Base"): + search_input = gr.Textbox( + label="Search Query", + placeholder="e.g., diabetes management guidelines", + lines=2, + ) + search_btn = gr.Button("Search", variant="primary") + search_output = gr.Textbox(label="Results", lines=15, interactive=False) + + def _call_search(query: str) -> str: + try: + with httpx.Client(timeout=30.0) as client: + resp = client.post( + f"{API_BASE}/search", + json={"query": query, "top_k": 5, "mode": "hybrid"}, + ) + resp.raise_for_status() + data = resp.json() + results = data.get("results", []) + if not results: + return "No results found." 
+ parts = [] + for i, r in enumerate(results, 1): + parts.append( + f"**[{i}] {r.get('title', 'Untitled')}** (score: {r.get('score', 0):.3f})\n" + f"{r.get('text', '')}\n" + ) + return "\n---\n".join(parts) + except Exception as exc: + return f"Error: {exc}" + + search_btn.click(fn=_call_search, inputs=search_input, outputs=search_output) + + demo.launch(server_name="0.0.0.0", server_port=7860, share=share) + + +if __name__ == "__main__": + launch_gradio() diff --git a/src/llm_config.py b/src/llm_config.py index 3dfc14615630c3ea4dd1b91c97a98fba2c0e5fb6..a2c8023a637c3c232752d2a34d614635b3d4ad55 100644 --- a/src/llm_config.py +++ b/src/llm_config.py @@ -19,8 +19,14 @@ load_dotenv() # Configure LangSmith tracing os.environ["LANGCHAIN_PROJECT"] = os.getenv("LANGCHAIN_PROJECT", "MediGuard_AI_RAG_Helper") -# Default provider (can be overridden via env) -DEFAULT_LLM_PROVIDER = os.getenv("LLM_PROVIDER", "groq") + +def get_default_llm_provider() -> str: + """Get default LLM provider dynamically from environment.""" + return os.getenv("LLM_PROVIDER", "groq") + + +# For backward compatibility (but prefer using get_default_llm_provider()) +DEFAULT_LLM_PROVIDER = get_default_llm_provider() def get_chat_model( @@ -41,7 +47,8 @@ def get_chat_model( Returns: LangChain chat model instance """ - provider = provider or DEFAULT_LLM_PROVIDER + # Use dynamic lookup to get current provider from environment + provider = provider or get_default_llm_provider() if provider == "groq": from langchain_groq import ChatGroq @@ -164,9 +171,11 @@ class LLMConfig: provider: LLM provider - "groq" (free), "gemini" (free), or "ollama" (local) lazy: If True, defer model initialization until first use (avoids API key errors at import) """ - self.provider = provider or DEFAULT_LLM_PROVIDER + # Store explicit provider or None to use dynamic lookup later + self._explicit_provider = provider self._lazy = lazy self._initialized = False + self._initialized_provider = None # Track which provider was initialized 
self._lock = threading.Lock() # Lazy-initialized model instances @@ -181,8 +190,28 @@ class LLMConfig: if not lazy: self._initialize_models() + @property + def provider(self) -> str: + """Get current provider (dynamic lookup if not explicitly set).""" + return self._explicit_provider or get_default_llm_provider() + + def _check_provider_change(self): + """Check if provider changed and reinitialize if needed.""" + current = self.provider + if self._initialized and self._initialized_provider != current: + print(f"Provider changed from {self._initialized_provider} to {current}, reinitializing...") + self._initialized = False + self._planner = None + self._analyzer = None + self._explainer = None + self._synthesizer_7b = None + self._synthesizer_8b = None + self._director = None + def _initialize_models(self): """Initialize all model clients (called on first use if lazy)""" + self._check_provider_change() + if self._initialized: return @@ -234,6 +263,7 @@ class LLMConfig: self._embedding_model = get_embedding_model() self._initialized = True + self._initialized_provider = self.provider @property def planner(self): diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000000000000000000000000000000000000..c60b542d7996160f238b7c9db4362fd85e121dfa --- /dev/null +++ b/src/main.py @@ -0,0 +1,220 @@ +""" +MediGuard AI β€” Production FastAPI Application + +Central app factory with lifespan that initialises all production services +(OpenSearch, Redis, Ollama, Langfuse, RAG pipeline) and gracefully shuts +them down. The existing ``api/`` package is kept as-is β€” this new module +becomes the primary production entry-point. 
+""" + +from __future__ import annotations + +import logging +import os +import time +from contextlib import asynccontextmanager +from datetime import datetime, timezone + +from fastapi import FastAPI, Request, status +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from src.settings import get_settings + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(name)-30s | %(levelname)-7s | %(message)s", +) +logger = logging.getLogger("mediguard") + +# --------------------------------------------------------------------------- +# Lifespan +# --------------------------------------------------------------------------- + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Initialise production services on startup, tear them down on shutdown.""" + settings = get_settings() + app.state.start_time = time.time() + app.state.version = "2.0.0" + + logger.info("=" * 70) + logger.info("MediGuard AI β€” starting production server v%s", app.state.version) + logger.info("=" * 70) + + # --- OpenSearch --- + try: + from src.services.opensearch.client import make_opensearch_client + app.state.opensearch_client = make_opensearch_client() + logger.info("OpenSearch client ready") + except Exception as exc: + logger.warning("OpenSearch unavailable: %s", exc) + app.state.opensearch_client = None + + # --- Embedding service --- + try: + from src.services.embeddings.service import make_embedding_service + app.state.embedding_service = make_embedding_service() + logger.info("Embedding service ready (provider=%s)", app.state.embedding_service._provider) + except Exception as exc: + logger.warning("Embedding service unavailable: %s", exc) + app.state.embedding_service = None + + # --- Redis cache --- 
+ try: + from src.services.cache.redis_cache import make_redis_cache + app.state.cache = make_redis_cache() + logger.info("Redis cache ready") + except Exception as exc: + logger.warning("Redis cache unavailable: %s", exc) + app.state.cache = None + + # --- Ollama LLM --- + try: + from src.services.ollama.client import make_ollama_client + app.state.ollama_client = make_ollama_client() + logger.info("Ollama client ready") + except Exception as exc: + logger.warning("Ollama client unavailable: %s", exc) + app.state.ollama_client = None + + # --- Langfuse tracer --- + try: + from src.services.langfuse.tracer import make_langfuse_tracer + app.state.tracer = make_langfuse_tracer() + logger.info("Langfuse tracer ready") + except Exception as exc: + logger.warning("Langfuse tracer unavailable: %s", exc) + app.state.tracer = None + + # --- Agentic RAG service --- + try: + from src.services.agents.agentic_rag import AgenticRAGService + from src.services.agents.context import AgenticContext + + if app.state.ollama_client and app.state.opensearch_client and app.state.embedding_service: + llm = app.state.ollama_client.get_langchain_model() + ctx = AgenticContext( + llm=llm, + embedding_service=app.state.embedding_service, + opensearch_client=app.state.opensearch_client, + cache=app.state.cache, + tracer=app.state.tracer, + ) + app.state.rag_service = AgenticRAGService(ctx) + logger.info("Agentic RAG service ready") + else: + app.state.rag_service = None + logger.warning("Agentic RAG service skipped β€” missing backing services") + except Exception as exc: + logger.warning("Agentic RAG service failed: %s", exc) + app.state.rag_service = None + + # --- Legacy RagBot service (backward-compatible /analyze) --- + try: + from api.app.services.ragbot import get_ragbot_service + ragbot = get_ragbot_service() + ragbot.initialize() + app.state.ragbot_service = ragbot + logger.info("Legacy RagBot service ready") + except Exception as exc: + logger.warning("Legacy RagBot service 
def create_app() -> FastAPI:
    """Build and return the configured FastAPI application.

    Registers the CORS middleware, JSON error handlers, the API routers and
    the root discovery endpoint.

    Environment:
        CORS_ALLOWED_ORIGINS: Comma-separated list of allowed origins.
            Whitespace around entries is ignored and empty entries are
            dropped; an unset or empty value means ``*`` (any origin).

    Returns:
        The configured ``FastAPI`` instance.
    """
    # Local import, in the same style as the router import further down.
    from fastapi.encoders import jsonable_encoder

    # NOTE(review): the returned settings object is not used in this function;
    # presumably the call is kept to load/validate configuration eagerly —
    # confirm before removing.
    settings = get_settings()

    app = FastAPI(
        title="MediGuard AI",
        description="Production medical biomarker analysis — agentic RAG + multi-agent workflow",
        version="2.0.0",
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json",
    )

    # --- CORS ---
    # "a.com, b.com" must yield ["a.com", "b.com"]; an entry with a stray
    # leading space would never match a browser's Origin header.
    raw_origins = os.getenv("CORS_ALLOWED_ORIGINS", "*")
    origins = [entry.strip() for entry in raw_origins.split(",") if entry.strip()] or ["*"]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        # Credentials are only enabled for an explicit allow-list, never when
        # a wildcard is present anywhere in the list.
        allow_credentials="*" not in origins,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # --- Exception handlers ---
    @app.exception_handler(RequestValidationError)
    async def validation_error(request: Request, exc: RequestValidationError):
        return JSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content={
                "status": "error",
                "error_code": "VALIDATION_ERROR",
                "message": "Request validation failed",
                # Validator errors can carry exception objects (e.g. the
                # ValueError raised by a field validator); encode them so the
                # JSON response itself cannot fail to serialise.
                "details": jsonable_encoder(exc.errors()),
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )

    @app.exception_handler(Exception)
    async def catch_all(request: Request, exc: Exception):
        logger.error("Unhandled exception: %s", exc, exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "status": "error",
                "error_code": "INTERNAL_SERVER_ERROR",
                "message": "An unexpected error occurred. Please try again later.",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )

    # --- Routers ---
    from src.routers import health, analyze, ask, search

    app.include_router(health.router)
    app.include_router(analyze.router)
    app.include_router(ask.router)
    app.include_router(search.router)

    @app.get("/")
    async def root():
        return {
            "name": "MediGuard AI",
            "version": "2.0.0",
            "status": "online",
            "endpoints": {
                "health": "/health",
                "health_ready": "/health/ready",
                "analyze_natural": "/analyze/natural",
                "analyze_structured": "/analyze/structured",
                "ask": "/ask",
                "search": "/search",
                "docs": "/docs",
            },
        }

    return app
class AnalysisRepository:
    """Data-access helper for ``PatientAnalysis`` rows.

    Every method works through the SQLAlchemy session supplied at
    construction time; nothing here commits the transaction.
    """

    def __init__(self, db: Session):
        self.db = db

    def create(self, analysis: PatientAnalysis) -> PatientAnalysis:
        """Add *analysis* to the session, flush, and return the same object."""
        session = self.db
        session.add(analysis)
        session.flush()
        return analysis

    def get_by_request_id(self, request_id: str) -> Optional[PatientAnalysis]:
        """Return the first analysis whose ``request_id`` matches, else ``None``."""
        matching = self.db.query(PatientAnalysis).filter(
            PatientAnalysis.request_id == request_id
        )
        return matching.first()

    def list_recent(self, limit: int = 20) -> List[PatientAnalysis]:
        """Return up to *limit* analyses, newest ``created_at`` first."""
        newest_first = self.db.query(PatientAnalysis).order_by(
            PatientAnalysis.created_at.desc()
        )
        return newest_first.limit(limit).all()

    def count(self) -> int:
        """Return the number of stored analyses."""
        all_rows = self.db.query(PatientAnalysis)
        return all_rows.count()
class DocumentRepository:
    """Data-access helper for ingested ``MedicalDocument`` rows.

    Every method works through the SQLAlchemy session supplied at
    construction time; nothing here commits the transaction.
    """

    def __init__(self, db: Session):
        self.db = db

    def upsert(self, doc: MedicalDocument) -> MedicalDocument:
        """Insert *doc*, or refresh the stored row with the same content hash.

        When a row with the same ``content_hash`` exists, only its
        ``parse_status``, ``chunk_count`` and ``indexed_at`` are overwritten
        and that stored row is returned; otherwise *doc* is added and returned.
        """
        session = self.db
        by_hash = session.query(MedicalDocument).filter(
            MedicalDocument.content_hash == doc.content_hash
        )
        existing = by_hash.first()
        if not existing:
            session.add(doc)
            session.flush()
            return doc

        existing.parse_status = doc.parse_status
        existing.chunk_count = doc.chunk_count
        existing.indexed_at = doc.indexed_at
        session.flush()
        return existing

    def get_by_id(self, doc_id: str) -> Optional[MedicalDocument]:
        """Return the document with primary key *doc_id*, else ``None``."""
        by_id = self.db.query(MedicalDocument).filter(MedicalDocument.id == doc_id)
        return by_id.first()

    def list_all(self, limit: int = 100) -> List[MedicalDocument]:
        """Return up to *limit* documents, newest ``created_at`` first."""
        newest_first = self.db.query(MedicalDocument).order_by(
            MedicalDocument.created_at.desc()
        )
        return newest_first.limit(limit).all()

    def count(self) -> int:
        """Return the number of stored documents."""
        all_rows = self.db.query(MedicalDocument)
        return all_rows.count()
async def _run_guild_analysis(
    request: Request,
    biomarkers: Dict[str, float],
    patient_ctx: Dict[str, Any],
    extracted_biomarkers: Dict[str, float] | None = None,
) -> AnalysisResponse:
    """Execute the ClinicalInsightGuild and build the response envelope.

    Args:
        request: Incoming request; the analysis service is read from
            ``request.app.state.ragbot_service``.
        biomarkers: Biomarker name -> value mapping handed to the service.
        patient_ctx: Patient context dict handed to the service.
        extracted_biomarkers: Biomarkers parsed from free text, echoed back
            in the response when provided.

    Returns:
        An ``AnalysisResponse`` combining the envelope fields built here with
        the remaining keys returned by the service.

    Raises:
        HTTPException: 503 when no analysis service is registered, 500 when
            the service raises.
    """
    request_id = f"req_{uuid.uuid4().hex[:12]}"
    # Monotonic clock: wall-clock adjustments must not skew the duration.
    started = time.perf_counter()

    ragbot = getattr(request.app.state, "ragbot_service", None)
    if ragbot is None:
        raise HTTPException(status_code=503, detail="Analysis service unavailable")

    try:
        result = await ragbot.analyze(biomarkers, patient_ctx)
    except Exception as exc:
        logger.exception("Guild analysis failed: %s", exc)
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(
            status_code=500,
            detail=f"Analysis pipeline error: {exc}",
        ) from exc

    elapsed_ms = (time.perf_counter() - started) * 1000

    # Envelope fields are always set here; drop any same-named keys from the
    # service result so they cannot collide with the keyword arguments below.
    envelope_keys = {
        "status",
        "request_id",
        "timestamp",
        "extracted_biomarkers",
        "input_biomarkers",
        "patient_context",
        "processing_time_ms",
    }
    passthrough = {key: value for key, value in result.items() if key not in envelope_keys}

    # The guild returns a dict shaped like AnalysisResponse — pass through
    return AnalysisResponse(
        status="success",
        request_id=request_id,
        timestamp=datetime.now(timezone.utc).isoformat(),
        extracted_biomarkers=extracted_biomarkers,
        input_biomarkers=biomarkers,
        patient_context=patient_ctx,
        processing_time_ms=round(elapsed_ms, 1),
        **passthrough,
    )
@router.post("/structured", response_model=AnalysisResponse)
async def analyze_structured(body: StructuredAnalysisRequest, request: Request):
    """Run the full analysis on biomarkers supplied as structured data.

    The optional patient context is dumped to a plain dict with unset fields
    omitted; a missing context becomes an empty dict.
    """
    if body.patient_context:
        patient_ctx = body.patient_context.model_dump(exclude_none=True)
    else:
        patient_ctx = {}
    response = await _run_guild_analysis(request, body.biomarkers, patient_ctx)
    return response
@router.post("/ask", response_model=AskResponse)
async def ask_medical_question(body: AskRequest, request: Request):
    """Answer a free-form medical question via agentic RAG.

    Args:
        body: Validated request with the question and optional biomarker /
            patient context.
        request: Incoming request; the RAG service is read from
            ``request.app.state.rag_service``.

    Returns:
        An ``AskResponse`` with the answer, guardrail score, document counts
        and processing time.

    Raises:
        HTTPException: 503 when no RAG service is registered, 500 when the
            pipeline raises.
    """
    # Local import keeps this handler self-contained.
    from fastapi.concurrency import run_in_threadpool

    rag_service = getattr(request.app.state, "rag_service", None)
    if rag_service is None:
        raise HTTPException(status_code=503, detail="RAG service unavailable")

    request_id = f"req_{uuid.uuid4().hex[:12]}"
    started = time.perf_counter()

    try:
        # ``ask`` is called synchronously; run it in the worker thread pool so
        # the event loop is not blocked for the duration of the pipeline.
        result = await run_in_threadpool(
            rag_service.ask,
            query=body.question,
            biomarkers=body.biomarkers,
            patient_context=body.patient_context or "",
        )
    except Exception as exc:
        logger.exception("Agentic RAG failed: %s", exc)
        raise HTTPException(status_code=500, detail=f"RAG pipeline error: {exc}") from exc

    elapsed_ms = (time.perf_counter() - started) * 1000

    return AskResponse(
        status="success",
        request_id=request_id,
        question=body.question,
        answer=result.get("final_answer", ""),
        guardrail_score=result.get("guardrail_score"),
        # ``or []`` also covers a key that is present but set to None.
        documents_retrieved=len(result.get("retrieved_documents") or []),
        documents_relevant=len(result.get("relevant_documents") or []),
        processing_time_ms=round(elapsed_ms, 1),
    )
@router.get("/health/ready", response_model=HealthResponse)
async def readiness_check(request: Request) -> HealthResponse:
    """Deep readiness probe — checks all backing services.

    Services are read from ``request.app.state`` and reported in the order
    opensearch, redis, ollama, langfuse.

    Overall status:
        * ``unhealthy`` — opensearch or ollama is unavailable;
        * ``degraded``  — no critical service is unavailable, but at least
          one service reports ``degraded``;
        * ``healthy``   — otherwise (redis / langfuse being unavailable does
          not change the overall status).
    """
    app_state = request.app.state
    uptime = time.time() - getattr(app_state, "start_time", time.time())
    services: list[ServiceHealth] = []

    # --- OpenSearch ---
    try:
        os_client = getattr(app_state, "opensearch_client", None)
        if os_client is not None:
            t0 = time.perf_counter()
            info = os_client.health()
            latency = (time.perf_counter() - t0) * 1000
            os_status = info.get("status", "unknown")
            services.append(
                ServiceHealth(
                    name="opensearch",
                    status="ok" if os_status in ("green", "yellow") else "degraded",
                    latency_ms=round(latency, 1),
                )
            )
        else:
            services.append(ServiceHealth(name="opensearch", status="unavailable"))
    except Exception as exc:
        services.append(ServiceHealth(name="opensearch", status="unavailable", detail=str(exc)))

    # --- Redis ---
    try:
        cache = getattr(app_state, "cache", None)
        if cache is not None:
            t0 = time.perf_counter()
            cache.set("__health__", "ok", ttl=10)
            latency = (time.perf_counter() - t0) * 1000
            services.append(ServiceHealth(name="redis", status="ok", latency_ms=round(latency, 1)))
        else:
            services.append(ServiceHealth(name="redis", status="unavailable"))
    except Exception as exc:
        services.append(ServiceHealth(name="redis", status="unavailable", detail=str(exc)))

    # --- Ollama ---
    try:
        ollama = getattr(app_state, "ollama_client", None)
        if ollama is not None:
            t0 = time.perf_counter()
            healthy = ollama.health()
            latency = (time.perf_counter() - t0) * 1000
            services.append(
                ServiceHealth(
                    name="ollama",
                    status="ok" if healthy else "degraded",
                    latency_ms=round(latency, 1),
                )
            )
        else:
            services.append(ServiceHealth(name="ollama", status="unavailable"))
    except Exception as exc:
        services.append(ServiceHealth(name="ollama", status="unavailable", detail=str(exc)))

    # --- Langfuse ---
    try:
        tracer = getattr(app_state, "tracer", None)
        if tracer is not None:
            services.append(ServiceHealth(name="langfuse", status="ok"))
        else:
            services.append(ServiceHealth(name="langfuse", status="unavailable"))
    except Exception as exc:
        services.append(ServiceHealth(name="langfuse", status="unavailable", detail=str(exc)))

    # Derive the overall status from the collected per-service results so a
    # service that answers but reports itself degraded is actually surfaced.
    critical = ("opensearch", "ollama")
    if any(s.status == "unavailable" for s in services if s.name in critical):
        overall = "unhealthy"
    elif any(s.status == "degraded" for s in services):
        overall = "degraded"
    else:
        overall = "healthy"

    return HealthResponse(
        status=overall,
        timestamp=datetime.now(timezone.utc).isoformat(),
        version=getattr(app_state, "version", "2.0.0"),
        uptime_seconds=round(uptime, 2),
        services=services,
    )
@router.post("/search", response_model=SearchResponse)
async def hybrid_search(body: SearchRequest, request: Request):
    """Execute a direct hybrid search against the OpenSearch index.

    Mode handling (case-insensitive, surrounding whitespace ignored):
        * ``bm25``   — keyword search only;
        * ``vector`` — embedding search; 503 if no embedding service;
        * anything else — hybrid search, falling back to BM25 when no
          embedding service is registered.

    Raises:
        HTTPException: 503 when the search client (or, for vector mode, the
            embedding service) is missing; 500 when the search call raises.
    """
    os_client = getattr(request.app.state, "opensearch_client", None)
    embedding_service = getattr(request.app.state, "embedding_service", None)

    if os_client is None:
        raise HTTPException(status_code=503, detail="Search service unavailable")

    # "BM25" / " vector " should select the mode the caller asked for rather
    # than silently falling through to hybrid.
    mode = body.mode.strip().lower()
    started = time.perf_counter()

    try:
        if mode == "bm25":
            results = os_client.search_bm25(query_text=body.query, top_k=body.top_k)
        elif mode == "vector":
            if embedding_service is None:
                raise HTTPException(status_code=503, detail="Embedding service unavailable for vector search")
            vec = embedding_service.embed_query(body.query)
            results = os_client.search_vector(query_vector=vec, top_k=body.top_k)
        else:
            # hybrid
            if embedding_service is None:
                logger.warning("Embedding service unavailable — falling back to BM25")
                results = os_client.search_bm25(query_text=body.query, top_k=body.top_k)
            else:
                vec = embedding_service.embed_query(body.query)
                results = os_client.search_hybrid(query_text=body.query, query_vector=vec, top_k=body.top_k)
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Search failed: %s", exc)
        raise HTTPException(status_code=500, detail=f"Search error: {exc}") from exc

    elapsed_ms = (time.perf_counter() - started) * 1000

    formatted = []
    for hit in results or []:
        # A hit may carry an explicit null ``_source`` / ``chunk_text``.
        source = hit.get("_source") or {}
        formatted.append(
            {
                "id": hit.get("_id", ""),
                "score": hit.get("_score", 0.0),
                "title": source.get("title", ""),
                "section": source.get("section_title", ""),
                "text": (source.get("chunk_text") or "")[:500],
            }
        )

    return SearchResponse(
        query=body.query,
        mode=body.mode,
        total_hits=len(formatted),
        results=formatted,
        processing_time_ms=round(elapsed_ms, 1),
    )
class SearchRequest(BaseModel):
    """Direct hybrid search (no LLM generation).

    Request body for the search endpoint: a query string, the number of hits
    to return and the search mode.
    """

    # Required search text, 2-1000 characters.
    query: str = Field(..., min_length=2, max_length=1000)
    # Number of hits requested; defaults to 10, limited to 1-100.
    top_k: int = Field(10, ge=1, le=100)
    # Free-form string defaulting to "hybrid"; the schema itself does not
    # restrict it to the values listed in the description.
    mode: str = Field("hybrid", description="Search mode: bm25 | vector | hybrid")
class AnalysisResponse(BaseModel):
    """Full clinical analysis response (backward-compatible).

    Top-level envelope returned by the analysis endpoints: request metadata,
    the biomarkers that were analysed, and the prediction / analysis payload.
    """

    # --- Envelope / request metadata ---
    status: str
    request_id: str
    timestamp: str
    # Optional; defaults to None when not supplied.
    extracted_biomarkers: Optional[Dict[str, float]] = None
    input_biomarkers: Dict[str, float]
    patient_context: Dict[str, Any]
    # --- Analysis payload (all required) ---
    prediction: Prediction
    analysis: Analysis
    agent_outputs: List[AgentOutput]
    workflow_metadata: Dict[str, Any]
    conversational_summary: Optional[str] = None
    # --- Timing / versioning ---
    processing_time_ms: float
    sop_version: Optional[str] = None
class HealthResponse(BaseModel):
    """Production health check.

    Response model for the health endpoints: an overall status plus an
    optional list of per-service results.
    """

    status: str  # healthy | degraded | unhealthy
    timestamp: str
    version: str
    uptime_seconds: float
    # Per-service results; defaults to an empty list when none are supplied.
    services: List[ServiceHealth] = Field(default_factory=list)
Orchestrator + +LangGraph StateGraph that wires all nodes into the guardrail β†’ retrieve β†’ grade β†’ generate pipeline. +""" + +from __future__ import annotations + +import logging +from functools import lru_cache, partial +from typing import Any + +from langgraph.graph import END, StateGraph + +from src.services.agents.context import AgenticContext +from src.services.agents.nodes.generate_answer_node import generate_answer_node +from src.services.agents.nodes.grade_documents_node import grade_documents_node +from src.services.agents.nodes.guardrail_node import guardrail_node +from src.services.agents.nodes.out_of_scope_node import out_of_scope_node +from src.services.agents.nodes.retrieve_node import retrieve_node +from src.services.agents.nodes.rewrite_query_node import rewrite_query_node +from src.services.agents.state import AgenticRAGState + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Edge routing helpers +# --------------------------------------------------------------------------- + + +def _route_after_guardrail(state: dict) -> str: + """Decide path after guardrail evaluation.""" + if state.get("routing_decision") == "analyze": + # Biomarker analysis pathway β€” goes straight to retrieve + return "retrieve" + if state.get("is_in_scope"): + return "retrieve" + return "out_of_scope" + + +def _route_after_grading(state: dict) -> str: + """Decide whether to rewrite query or proceed to generation.""" + if state.get("needs_rewrite"): + return "rewrite_query" + if not state.get("relevant_documents"): + return "generate_answer" # will produce a "no evidence found" answer + return "generate_answer" + + +# --------------------------------------------------------------------------- +# Graph builder +# --------------------------------------------------------------------------- + + +def build_agentic_rag_graph(context: AgenticContext) -> Any: + """Construct the compiled LangGraph for the 
class AgenticRAGService:
    """High-level wrapper around the compiled RAG graph."""

    def __init__(self, context: AgenticContext) -> None:
        self._context = context
        self._graph = build_agentic_rag_graph(context)

    def ask(
        self,
        query: str,
        biomarkers: dict | None = None,
        patient_context: str = "",
    ) -> dict:
        """Run the full agentic RAG pipeline and return the final state.

        Args:
            query: The user's question.
            biomarkers: Optional biomarker mapping placed in the initial state.
            patient_context: Optional free-text patient context.

        Returns:
            The state dict produced by the graph. If the graph raises, the
            initial state extended with an apology in ``final_answer`` and
            the error message in ``errors``; this method does not raise for
            pipeline failures.
        """
        initial_state: dict[str, Any] = {
            "query": query,
            "biomarkers": biomarkers,
            "patient_context": patient_context,
            "errors": [],
        }

        # Tracing is best-effort: a tracer failure must neither abort the
        # pipeline nor replace a successful result with an error.
        tracer = self._context.tracer
        span = None
        if tracer:
            try:
                span = tracer.start_span(
                    name="agentic_rag_ask",
                    metadata={"query": query},
                )
            except Exception as exc:
                logger.warning("Tracer start_span failed: %s", exc)

        try:
            return self._graph.invoke(initial_state)
        except Exception as exc:
            logger.error("Agentic RAG pipeline failed: %s", exc, exc_info=True)
            return {
                **initial_state,
                "final_answer": (
                    "I apologize, but I'm temporarily unable to process your request. "
                    "Please consult a healthcare professional."
                ),
                "errors": [str(exc)],
            }
        finally:
            if span is not None:
                try:
                    tracer.end_span(span)
                except Exception as exc:
                    logger.warning("Tracer end_span failed: %s", exc)
@dataclass(frozen=True)
class AgenticContext:
    """Immutable runtime context for agentic RAG nodes.

    Frozen dataclass bundling the services a node may need, so nodes receive
    their dependencies explicitly instead of reading globals. All fields are
    typed ``Any``; the trailing comments name the intended concrete types.
    """

    llm: Any  # LangChain chat model
    embedding_service: Any  # EmbeddingService
    opensearch_client: Any  # OpenSearchClient
    cache: Any  # RedisCache
    tracer: Any  # LangfuseTracer
    guild: Optional[Any] = None  # ClinicalInsightGuild (original workflow)
def generate_answer_node(state: dict, *, context: Any) -> dict:
    """Generate a cited medical answer from relevant documents.

    Args:
        state: Graph state. Reads ``rewritten_query`` (preferred) or
            ``query``, plus ``relevant_documents``, ``biomarkers`` and
            ``patient_context``.
        context: Runtime context; only ``context.llm`` is used.

    Returns:
        ``{"final_answer": ...}`` on success. If the LLM call fails, a
        fallback answer plus ``"errors": [<message>]``.
    """
    query = state.get("rewritten_query") or state.get("query", "")
    # ``or []`` also covers a key that is present but set to None.
    documents = state.get("relevant_documents") or []
    biomarkers = state.get("biomarkers")
    patient_context = state.get("patient_context", "")

    # Build evidence block: one numbered entry per document, text capped at
    # 2000 characters.
    evidence_parts: list[str] = []
    for i, doc in enumerate(documents, 1):
        title = doc.get("title", "Unknown")
        section = doc.get("section", "")
        text = (doc.get("text") or "")[:2000]  # tolerate a null text field
        header = f"[{i}] {title}"
        if section:
            header += f" — {section}"
        evidence_parts.append(f"{header}\n{text}")
    evidence_block = "\n\n---\n\n".join(evidence_parts) if evidence_parts else "(No evidence retrieved)"

    # Build user message
    user_msg = f"Question: {query}\n\n"
    if biomarkers:
        user_msg += f"Biomarkers: {biomarkers}\n\n"
    if patient_context:
        user_msg += f"Patient context: {patient_context}\n\n"
    user_msg += f"Evidence:\n{evidence_block}"

    try:
        response = context.llm.invoke(
            [
                {"role": "system", "content": RAG_GENERATION_SYSTEM},
                {"role": "user", "content": user_msg},
            ]
        )
        answer = response.content.strip()
    except Exception as exc:
        # Keep the traceback in the log; the caller only gets the message.
        logger.error("Generation LLM failed: %s", exc, exc_info=True)
        answer = (
            "I apologize, but I'm temporarily unable to generate a response. "
            "Please consult a healthcare professional for guidance."
        )
        return {"final_answer": answer, "errors": [str(exc)]}

    return {"final_answer": answer}
def _parse_guardrail_score(content: str) -> float:
    """Extract the relevance score from the guardrail LLM reply.

    Accepts bare JSON, JSON inside a Markdown code fence, JSON surrounded by
    prose, and JSON wrapped in doubled braces. The result is clamped to the
    0-100 range.

    Raises:
        ValueError: If no JSON object can be decoded, the decoded value is not
            an object, or the score is NaN (``json.JSONDecodeError`` is a
            ``ValueError`` subclass).
        TypeError: If the ``score`` field cannot be converted to ``float``.
    """
    text = content.strip()
    if "```" in text:
        # Keep only the body of the first fenced block, minus a "json" tag.
        text = text.split("```")[1].strip()
        if text.startswith("json"):
            text = text[4:].strip()
    try:
        data = json.loads(text)
    except ValueError:
        # Fall back to the innermost {...} span: this covers prose around the
        # object as well as a reply that wraps the object in doubled braces.
        start = text.rfind("{")
        end = text.find("}", start)
        if start == -1 or end == -1:
            raise
        data = json.loads(text[start : end + 1])
    if not isinstance(data, dict):
        raise ValueError("guardrail reply is not a JSON object")
    score = float(data.get("score", 0))
    if score != score:  # NaN compares unequal to itself
        raise ValueError("guardrail score is NaN")
    return min(100.0, max(0.0, score))


def guardrail_node(state: dict, *, context: Any) -> dict:
    """Score the query for medical relevance (0-100) and pick a route.

    Args:
        state: Graph state; reads ``query`` and ``biomarkers``.
        context: Object exposing ``llm.invoke(messages)``; the reply text is
            read from the result's ``content`` attribute.

    Returns:
        A dict with ``guardrail_score``, ``is_in_scope`` and
        ``routing_decision``. Truthy biomarkers short-circuit to a score of
        95.0 routed to ``"analyze"`` without calling the LLM. Otherwise the
        query is in scope (``"rag_answer"``) when the score is at least 40 and
        ``"out_of_scope"`` below that. Any LLM or parsing failure falls back to
        a score of 70.0, i.e. the query is given the benefit of the doubt.
    """
    query = state.get("query", "")

    # Fast path: if biomarkers are provided, it's definitely medical
    if state.get("biomarkers"):
        return {
            "guardrail_score": 95.0,
            "is_in_scope": True,
            "routing_decision": "analyze",
        }

    try:
        response = context.llm.invoke(
            [
                {"role": "system", "content": GUARDRAIL_SYSTEM},
                {"role": "user", "content": query},
            ]
        )
        score = _parse_guardrail_score(response.content)
    except Exception as exc:
        logger.warning("Guardrail LLM failed: %s — defaulting to in-scope", exc)
        score = 70.0  # benefit of the doubt

    is_in_scope = score >= 40
    return {
        "guardrail_score": score,
        "is_in_scope": is_in_scope,
        "routing_decision": "rag_answer" if is_in_scope else "out_of_scope",
    }
def retrieve_node(state: dict, *, context: Any) -> dict:
    """Retrieve documents via hybrid search, with a cache in front.

    Args:
        state: Graph state; reads ``rewritten_query`` and falls back to
            ``query``.
        context: Object exposing ``cache`` (may be falsy),
            ``embedding_service.embed_query`` and an ``opensearch_client``
            with ``search_hybrid`` / ``search_bm25``.

    Returns:
        ``{"retrieved_documents": [...]}`` where every document has the keys
        ``id``, ``score``, ``text``, ``title``, ``section`` and ``metadata``.
        When embedding fails, or both the hybrid search and the BM25 fallback
        fail, the list is empty and an ``errors`` list carries the messages.
    """
    query = state.get("rewritten_query") or state.get("query", "")

    # 1. Try cache first
    cache_key = f"retrieve:{query}"
    if context.cache:
        cached = context.cache.get(cache_key)
        if cached is not None:
            logger.debug("Cache hit for retrieve query")
            return {"retrieved_documents": cached}

    # 2. Embed the query
    try:
        query_embedding = context.embedding_service.embed_query(query)
    except Exception as exc:
        logger.error("Embedding failed: %s", exc)
        return {"retrieved_documents": [], "errors": [str(exc)]}

    # 3. Hybrid search, falling back to keyword-only search
    try:
        results = context.opensearch_client.search_hybrid(
            query_text=query,
            query_vector=query_embedding,
            top_k=10,
        )
    except Exception as exc:
        logger.error("OpenSearch hybrid search failed: %s — falling back to BM25", exc)
        try:
            results = context.opensearch_client.search_bm25(
                query_text=query,
                top_k=10,
            )
        except Exception as exc2:
            logger.error("BM25 fallback also failed: %s", exc2)
            return {"retrieved_documents": [], "errors": [str(exc), str(exc2)]}

    documents = []
    for hit in results or []:
        # A hit may carry ``_source: null``; treat it like a missing source.
        source = hit.get("_source") or {}
        documents.append(
            {
                "id": hit.get("_id", ""),
                "score": hit.get("_score", 0.0),
                "text": source.get("chunk_text") or "",
                "title": source.get("title") or "",
                "section": source.get("section_title") or "",
                "metadata": source,
            }
        )

    # 4. Store in cache (5 min TTL). The cache API takes the value first and
    #    the key parts after it: ``set(value, *key_parts, ttl=...)``.
    if context.cache:
        context.cache.set(documents, cache_key, ttl=300)

    return {"retrieved_documents": documents}
def rewrite_query_node(state: dict, *, context: Any) -> dict:
    """Rewrite the original query for better retrieval.

    Args:
        state: Graph state; reads ``query`` and ``patient_context``.
        context: Object exposing ``llm.invoke(messages)``; the reply text is
            read from the result's ``content`` attribute.

    Returns:
        ``{"rewritten_query": <text>}``. The original query is kept when the
        LLM call fails or the reply is empty once whitespace and one pair of
        wrapping quotes have been removed.
    """
    original = state.get("query", "")
    patient_context = state.get("patient_context", "")

    user_msg = f"Original query: {original}"
    if patient_context:
        user_msg += f"\n\nPatient context: {patient_context}"

    try:
        response = context.llm.invoke(
            [
                {"role": "system", "content": REWRITE_SYSTEM},
                {"role": "user", "content": user_msg},
            ]
        )
        rewritten = response.content.strip()
        # The prompt asks for no quotes, but models still add them at times;
        # drop one matching pair so they do not leak into the search query.
        if len(rewritten) >= 2 and rewritten[0] == rewritten[-1] and rewritten[0] in "\"'`":
            rewritten = rewritten[1:-1].strip()
        if not rewritten:
            rewritten = original
    except Exception as exc:
        logger.warning("Rewrite LLM failed: %s — keeping original query", exc)
        rewritten = original

    return {"rewritten_query": rewritten}
+ +Score the query from 0 to 100: + 90-100 Clearly medical (biomarker values, disease questions, symptoms) + 60-89 Health-adjacent (nutrition, fitness, wellness) + 30-59 Loosely related (general biology, anatomy trivia) + 0-29 Not medical at all (weather, coding, sports) + +Respond ONLY with JSON: +{{"score": , "reason": ""}} +""" + +# ── Document grading prompt ────────────────────────────────────────────────── + +GRADING_SYSTEM = """\ +You are a medical-relevance grader. Given a user question and a retrieved +document chunk, decide whether the document is relevant to answering the +medical question. + +Respond ONLY with JSON: +{{"relevant": true/false, "reason": ""}} +""" + +# ── Query rewriting prompt ─────────────────────────────────────────────────── + +REWRITE_SYSTEM = """\ +You are a medical-query optimiser. The original user query did not +retrieve relevant medical documents. Rewrite it to improve retrieval from +a medical knowledge base. + +Guidelines: +- Use standard medical terminology +- Add synonyms for biomarker names +- Make the intent clearer + +Respond with ONLY the rewritten query (no explanation, no quotes). +""" + +# ── RAG generation prompt ──────────────────────────────────────────────────── + +RAG_GENERATION_SYSTEM = """\ +You are MediGuard AI, a clinical-information assistant. +Answer the user's medical question using ONLY the provided context documents. +If the context is insufficient, say so honestly. + +Rules: +1. Cite specific documents with [Source: filename, Page X]. +2. Use patient-friendly language. +3. Never provide a definitive diagnosis β€” use "may indicate", "suggests". +4. Always end with: "Please consult a healthcare professional for diagnosis." +5. If biomarker values are critical, highlight them as safety alerts. 
+""" + +# ── Out-of-scope response ─────────────────────────────────────────────────── + +OUT_OF_SCOPE_RESPONSE = ( + "I'm MediGuard AI β€” I specialise in medical biomarker analysis and " + "health-related questions. Your query doesn't appear to be about a " + "medical or health topic I can help with. Please try asking about " + "biomarker values, disease information, or clinical guidelines." +) diff --git a/src/services/agents/state.py b/src/services/agents/state.py new file mode 100644 index 0000000000000000000000000000000000000000..e87308c359bd0b5cd2592217a6d50f3bdd30a7d8 --- /dev/null +++ b/src/services/agents/state.py @@ -0,0 +1,47 @@ +""" +MediGuard AI β€” Agentic RAG State + +Enhanced LangGraph state for the guardrail β†’ retrieve β†’ grade β†’ generate +pipeline that wraps the existing 6-agent clinical workflow. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Annotated +from typing_extensions import TypedDict +import operator + + +class AgenticRAGState(TypedDict): + """State flowing through the agentic RAG graph.""" + + # ── Input ──────────────────────────────────────────────────────────── + query: str + biomarkers: Optional[Dict[str, float]] + patient_context: Optional[Dict[str, Any]] + + # ── Guardrail ──────────────────────────────────────────────────────── + guardrail_score: float # 0-100 medical-relevance score + is_in_scope: bool # passed guardrail? 
+ + # ── Retrieval ──────────────────────────────────────────────────────── + retrieved_documents: List[Dict[str, Any]] + retrieval_attempts: int + max_retrieval_attempts: int + + # ── Grading ────────────────────────────────────────────────────────── + grading_results: List[Dict[str, Any]] + relevant_documents: List[Dict[str, Any]] + needs_rewrite: bool + + # ── Rewriting ──────────────────────────────────────────────────────── + rewritten_query: Optional[str] + + # ── Generation / routing ───────────────────────────────────────────── + routing_decision: str # "analyze" | "rag_answer" | "out_of_scope" + final_answer: Optional[str] + analysis_result: Optional[Dict[str, Any]] + + # ── Metadata ───────────────────────────────────────────────────────── + trace_id: Optional[str] + errors: Annotated[List[str], operator.add] diff --git a/src/services/biomarker/__init__.py b/src/services/biomarker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f6a6916e0528a2270c0da43ec30242d61d4f603 --- /dev/null +++ b/src/services/biomarker/__init__.py @@ -0,0 +1 @@ +"""MediGuard AI β€” Biomarker validation service.""" diff --git a/src/services/biomarker/service.py b/src/services/biomarker/service.py new file mode 100644 index 0000000000000000000000000000000000000000..84dd76d215459140c51e263335c7fd7742a7de02 --- /dev/null +++ b/src/services/biomarker/service.py @@ -0,0 +1,110 @@ +""" +MediGuard AI β€” Biomarker Validation Service + +Wraps the existing BiomarkerValidator as a production service with caching, +observability, and Pydantic-typed outputs. 
@dataclass(frozen=True)
class BiomarkerResult:
    """Validated result for a single biomarker."""

    name: str
    value: float
    unit: str
    status: str  # NORMAL | HIGH | LOW | CRITICAL_HIGH | CRITICAL_LOW
    reference_range: str
    warning: Optional[str] = None


@dataclass
class ValidationReport:
    """Complete biomarker validation report."""

    results: List[BiomarkerResult] = field(default_factory=list)
    safety_alerts: List[Dict[str, Any]] = field(default_factory=list)
    recognized_count: int = 0
    unrecognized: List[str] = field(default_factory=list)


class BiomarkerService:
    """Production biomarker validation service wrapping ``BiomarkerValidator``."""

    def __init__(self) -> None:
        self._validator = BiomarkerValidator()

    # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #

    def validate(
        self,
        biomarkers: Dict[str, float],
        gender: Optional[str] = None,
    ) -> ValidationReport:
        """Validate a mapping of biomarker name to value.

        Args:
            biomarkers: Raw biomarker names mapped to measured values. Names
                are normalised before being handed to the validator.
            gender: Forwarded unchanged to the validator.

        Returns:
            A report listing every recognised biomarker, the raw names that
            were not recognised (validator returned ``None`` or a flag whose
            status is ``"UNKNOWN"``), and one safety alert per flag whose
            status starts with ``"CRITICAL"``.
        """
        report = ValidationReport()

        for raw_name, value in biomarkers.items():
            normalized = normalize_biomarker_name(raw_name)
            flag = self._validator.validate_biomarker(normalized, value, gender=gender)
            if flag is None or flag.status == "UNKNOWN":
                # Report the name exactly as the caller supplied it.
                report.unrecognized.append(raw_name)
                continue

            report.results.append(
                BiomarkerResult(
                    name=flag.name,
                    value=flag.value,
                    unit=flag.unit,
                    status=flag.status,
                    reference_range=flag.reference_range,
                    warning=flag.warning,
                )
            )
            if flag.status.startswith("CRITICAL"):
                report.safety_alerts.append(
                    {
                        "severity": "CRITICAL",
                        "biomarker": normalized,
                        "message": flag.warning or f"{normalized} is critically out of range",
                        "action": "Seek immediate medical attention",
                    }
                )

        # Every recognised biomarker produced exactly one result above.
        report.recognized_count = len(report.results)
        return report

    def list_supported(self) -> List[Dict[str, Any]]:
        """Return metadata for every biomarker in the validator's references."""
        return [
            {
                "name": name,
                "unit": ref.get("unit", ""),
                "normal_range": ref.get("normal_range", {}),
                "critical_low": ref.get("critical_low"),
                "critical_high": ref.get("critical_high"),
            }
            for name, ref in self._validator.references.items()
        ]


@lru_cache(maxsize=1)
def make_biomarker_service() -> BiomarkerService:
    """Return the process-wide ``BiomarkerService`` (built on first call)."""
    return BiomarkerService()
class RedisCache:
    """Thin Redis wrapper with SHA-256 key generation and JSON ser/de.

    Every operation is best-effort: when the cache is disabled or the Redis
    call (or key/JSON handling) fails, the method reports a miss/failure
    instead of raising.
    """

    def __init__(self, client: Any, default_ttl: int = 21600):
        self._client = client
        self._default_ttl = default_ttl
        self._enabled = client is not None

    @property
    def enabled(self) -> bool:
        """Whether a client was supplied at construction time."""
        return self._enabled

    def ping(self) -> bool:
        """Return ``True`` only if the client answers a ping."""
        if not self._enabled:
            return False
        try:
            return bool(self._client.ping())
        except Exception:
            return False

    @staticmethod
    def _make_key(*parts: str) -> str:
        """Hash the ``|``-joined parts into a ``mediguard:``-prefixed key."""
        # NOTE(review): parts that themselves contain "|" can collide, e.g.
        # ("a|b",) and ("a", "b"). Kept as-is so existing keys stay valid.
        raw = "|".join(parts)
        return f"mediguard:{hashlib.sha256(raw.encode()).hexdigest()}"

    def get(self, *key_parts: str) -> Optional[Any]:
        """Return the decoded JSON value stored under the key, or ``None``."""
        if not self._enabled:
            return None
        try:
            # Key derivation is inside the guard so that a non-string key
            # part degrades to a cache miss instead of raising TypeError.
            key = self._make_key(*key_parts)
            value = self._client.get(key)
            return None if value is None else json.loads(value)
        except Exception as exc:
            logger.warning("Cache GET failed: %s", exc)
            return None

    def set(self, value: Any, *key_parts: str, ttl: Optional[int] = None) -> bool:
        """Store *value* as JSON under the key; return whether it succeeded.

        Note the argument order: the value comes first, the key parts after
        it. A falsy *ttl* selects the default TTL.
        """
        if not self._enabled:
            return False
        try:
            key = self._make_key(*key_parts)
            self._client.setex(key, ttl or self._default_ttl, json.dumps(value, default=str))
            return True
        except Exception as exc:
            logger.warning("Cache SET failed: %s", exc)
            return False

    def delete(self, *key_parts: str) -> bool:
        """Delete the key; return whether the call succeeded."""
        if not self._enabled:
            return False
        try:
            self._client.delete(self._make_key(*key_parts))
            return True
        except Exception as exc:
            logger.warning("Cache DELETE failed: %s", exc)
            return False

    def flush(self) -> bool:
        """Flush the client's current database; return whether it succeeded."""
        if not self._enabled:
            return False
        try:
            # NOTE(review): flushdb() is not limited to "mediguard:" keys;
            # confirm nothing else shares this Redis database.
            self._client.flushdb()
            return True
        except Exception as exc:
            logger.warning("Cache FLUSH failed: %s", exc)
            return False


class _NullCache(RedisCache):
    """No-op cache returned when Redis is disabled or unavailable."""

    def __init__(self):
        super().__init__(client=None)


@lru_cache(maxsize=1)
def make_redis_cache() -> RedisCache:
    """Factory — returns a live cache or a silent null-cache."""
    settings = get_settings()
    if not settings.redis.enabled or _redis is None:
        logger.info("Redis caching disabled")
        return _NullCache()
    try:
        client = _redis.Redis(
            host=settings.redis.host,
            port=settings.redis.port,
            db=settings.redis.db,
            decode_responses=True,
            socket_connect_timeout=3,
        )
        client.ping()  # fail fast so an unreachable server yields the null-cache
        logger.info("Redis connected (%s:%d)", settings.redis.host, settings.redis.port)
        return RedisCache(client, settings.redis.ttl_seconds)
    except Exception as exc:
        logger.warning("Redis unavailable (%s), running without cache", exc)
        return _NullCache()
class EmbeddingService:
    """Unified embedding interface — delegates to the configured provider.

    Attributes are set from the constructor arguments: ``provider_name`` is
    used in error messages and ``dimension`` is stored for callers to read.
    """

    def __init__(self, model, provider_name: str, dimension: int):
        self._model = model
        self.provider_name = provider_name
        self.dimension = dimension

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Raises:
            EmbeddingProviderError: If the underlying model raises; the
                original exception is attached as ``__cause__``.
        """
        try:
            return self._model.embed_query(text)
        except Exception as exc:
            raise EmbeddingProviderError(f"{self.provider_name} embed_query failed: {exc}") from exc

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-embed a list of texts.

        Raises:
            EmbeddingProviderError: If the underlying model raises; the
                original exception is attached as ``__cause__``.
        """
        try:
            return self._model.embed_documents(texts)
        except Exception as exc:
            raise EmbeddingProviderError(f"{self.provider_name} embed_documents failed: {exc}") from exc
@lru_cache(maxsize=1)
def make_embedding_service() -> EmbeddingService:
    """Create an embedding service with automatic fallback.

    The configured provider is tried first, followed by the remaining
    entries of ``FALLBACK_ORDER``. The first factory that succeeds wins and
    the result is memoised for the process.

    Raises:
        EmbeddingError: If every candidate provider fails; the message lists
            each provider's failure reason.
    """
    settings = get_settings()
    preferred = settings.embedding.provider

    if preferred not in _PROVIDERS:
        # A typo in configuration would otherwise be skipped without a trace.
        logger.warning(
            "Unknown embedding provider '%s' (known: %s); using fallbacks only",
            preferred,
            ", ".join(sorted(_PROVIDERS)),
        )

    # Try preferred first, then fallbacks
    order = [preferred] + [p for p in FALLBACK_ORDER if p != preferred]
    failures: list[str] = []
    for provider in order:
        factory = _PROVIDERS.get(provider)
        if factory is None:
            continue
        try:
            svc = factory()
        except Exception as exc:
            logger.warning("Embedding provider '%s' failed: %s — trying next", provider, exc)
            failures.append(f"{provider}: {exc}")
        else:
            logger.info("Embedding provider: %s (dim=%d)", svc.provider_name, svc.dimension)
            return svc

    detail = f" Tried: {'; '.join(failures)}" if failures else ""
    raise EmbeddingError(
        "All embedding providers failed. Check your API keys and configuration." + detail
    )
+""" + +from __future__ import annotations + +import logging +import uuid +from datetime import datetime, timezone +from typing import Dict, List + +from src.services.indexing.text_chunker import MedicalChunk, MedicalTextChunker + +logger = logging.getLogger(__name__) + + +class IndexingService: + """Coordinates chunking β†’ embedding β†’ OpenSearch indexing.""" + + def __init__(self, chunker, embedding_service, opensearch_client): + self.chunker = chunker + self.embedding_service = embedding_service + self.opensearch_client = opensearch_client + + def index_text( + self, + text: str, + *, + document_id: str = "", + title: str = "", + source_file: str = "", + ) -> int: + """Chunk, embed, and index a single document's text. Returns count of indexed chunks.""" + if not document_id: + document_id = str(uuid.uuid4()) + + chunks = self.chunker.chunk_text( + text, + document_id=document_id, + title=title, + source_file=source_file, + ) + if not chunks: + logger.warning("No chunks generated for document '%s'", title) + return 0 + + # Embed all chunks + texts = [c.text for c in chunks] + embeddings = self.embedding_service.embed_documents(texts) + + # Prepare OpenSearch documents + now = datetime.now(timezone.utc).isoformat() + docs: List[Dict] = [] + for chunk, emb in zip(chunks, embeddings): + doc = chunk.to_dict() + doc["_id"] = f"{document_id}_{chunk.chunk_index}" + doc["embedding"] = emb + doc["indexed_at"] = now + docs.append(doc) + + indexed = self.opensearch_client.bulk_index(docs) + logger.info( + "Indexed %d chunks for '%s' (document_id=%s)", + indexed, title, document_id, + ) + return indexed + + def index_chunks(self, chunks: List[MedicalChunk]) -> int: + """Embed and index pre-built chunks.""" + if not chunks: + return 0 + texts = [c.text for c in chunks] + embeddings = self.embedding_service.embed_documents(texts) + now = datetime.now(timezone.utc).isoformat() + docs: List[Dict] = [] + for chunk, emb in zip(chunks, embeddings): + doc = chunk.to_dict() + 
# Biomarker names to detect in chunk text
_BIOMARKER_NAMES: Set[str] = {
    "Glucose", "Cholesterol", "Triglycerides", "HbA1c", "LDL", "HDL",
    "Insulin", "BMI", "Hemoglobin", "Platelets", "WBC", "RBC",
    "Hematocrit", "MCV", "MCH", "MCHC", "Heart Rate", "Systolic",
    "Diastolic", "Troponin", "CRP", "C-reactive Protein", "ALT", "AST",
    "Creatinine", "TSH", "T3", "T4", "Sodium", "Potassium", "Calcium",
}

# Lower-cased name -> canonical spelling, used to normalise regex hits.
_BIOMARKER_CANONICAL: Dict[str, str] = {name.lower(): name for name in _BIOMARKER_NAMES}

# One whole-token, case-insensitive pattern for every biomarker. Longer names
# come first so "MCHC" is not consumed as "MCH"; the look-arounds stop short
# names such as "ALT"/"AST" from matching inside "although"/"last".
_BIOMARKER_RE = re.compile(
    r"(?<!\w)(?:"
    + "|".join(
        r"\s+".join(re.escape(part) for part in name.split())
        for name in sorted(_BIOMARKER_NAMES, key=lambda n: (-len(n), n))
    )
    + r")(?!\w)",
    re.IGNORECASE,
)

_CONDITION_KEYWORDS: Dict[str, str] = {
    "diabetes": "diabetes",
    "diabetic": "diabetes",
    "hyperglycemia": "diabetes",
    "insulin resistance": "diabetes",
    "anemia": "anemia",
    "anaemia": "anemia",
    "iron deficiency": "anemia",
    "thalassemia": "thalassemia",
    "thalassaemia": "thalassemia",
    "thrombocytopenia": "thrombocytopenia",
    "heart disease": "heart_disease",
    "cardiovascular": "heart_disease",
    "coronary": "heart_disease",
    "hypertension": "heart_disease",
    "atherosclerosis": "heart_disease",
    "hyperlipidemia": "heart_disease",
}

# NOTE(review): this matches any line that merely *starts* with one of these
# words (e.g. "Results show that ..."), not only stand-alone headings.
_SECTION_RE = re.compile(
    r"^(?:#+\s*)?("
    r"abstract|introduction|background|methods?|methodology|materials?"
    r"|results?|findings|discussion|conclusion|summary"
    r"|guidelines?|recommendations?|references?|bibliography"
    r"|clinical\s*presentation|pathophysiology|diagnosis|treatment|prognosis"
    r")\b",
    re.IGNORECASE | re.MULTILINE,
)

# Section headers whose bodies are dropped; the regex also accepts the singular.
_SKIPPED_SECTIONS = ("reference", "references", "bibliography")


@dataclass
class MedicalChunk:
    """A single chunk of document text together with its medical metadata."""

    text: str
    chunk_index: int
    document_id: str = ""
    title: str = ""
    source_file: str = ""
    page_number: Optional[int] = None
    section_title: str = ""
    biomarkers_mentioned: List[str] = field(default_factory=list)
    condition_tags: List[str] = field(default_factory=list)
    word_count: int = 0

    def to_dict(self) -> Dict:
        """Return the fields that are indexed; ``word_count`` is not emitted."""
        return {
            "chunk_text": self.text,
            "chunk_index": self.chunk_index,
            "document_id": self.document_id,
            "title": self.title,
            "source_file": self.source_file,
            "page_number": self.page_number,
            "section_title": self.section_title,
            "biomarkers_mentioned": self.biomarkers_mentioned,
            "condition_tags": self.condition_tags,
        }


class MedicalTextChunker:
    """Section-aware text chunker optimised for medical documents.

    Args:
        target_words: Maximum number of words per chunk (must be > 0).
        overlap_words: Words shared between consecutive chunks of a section
            (``0 <= overlap_words < target_words``).
        min_words: Windows shorter than this are folded into the previous
            chunk instead of becoming their own chunk
            (``0 <= min_words <= target_words``).

    Raises:
        ValueError: If the sizes would stall the sliding window or silently
            drop text.
    """

    def __init__(
        self,
        target_words: int = 600,
        overlap_words: int = 100,
        min_words: int = 50,
    ):
        if target_words <= 0:
            raise ValueError("target_words must be positive")
        if not 0 <= overlap_words < target_words:
            # overlap >= target means the window start never advances.
            raise ValueError("overlap_words must satisfy 0 <= overlap_words < target_words")
        if not 0 <= min_words <= target_words:
            # min > target would merge a non-final window and discard the rest.
            raise ValueError("min_words must satisfy 0 <= min_words <= target_words")
        self.target_words = target_words
        self.overlap_words = overlap_words
        self.min_words = min_words

    def chunk_text(
        self,
        text: str,
        *,
        document_id: str = "",
        title: str = "",
        source_file: str = "",
    ) -> List[MedicalChunk]:
        """Split text into enriched medical chunks.

        Each detected section is walked with a sliding window of
        ``target_words`` words that steps back ``overlap_words`` words between
        windows. ``chunk_index`` is consecutive across the whole document.
        """
        chunks: List[MedicalChunk] = []
        for section_title, section_text in self._split_sections(text):
            words = section_text.split()
            start = 0
            while start < len(words):
                end = min(start + self.target_words, len(words))
                window = words[start:end]
                if len(window) < self.min_words and chunks:
                    # Fold a tiny window into the previous chunk. When it
                    # continues the same section (start > 0) its first
                    # ``overlap_words`` words are already in that chunk, so
                    # only the unseen remainder is appended.
                    fresh = window[self.overlap_words:] if start > 0 else window
                    previous = chunks[-1]
                    previous.text = " ".join([previous.text, *fresh])
                    previous.word_count = len(previous.text.split())
                    # Metadata must reflect the appended words as well.
                    previous.biomarkers_mentioned = self._detect_biomarkers(previous.text)
                    previous.condition_tags = self._detect_conditions(previous.text)
                    break

                piece = " ".join(window)
                chunks.append(
                    MedicalChunk(
                        text=piece,
                        chunk_index=len(chunks),
                        document_id=document_id,
                        title=title,
                        source_file=source_file,
                        section_title=section_title,
                        biomarkers_mentioned=self._detect_biomarkers(piece),
                        condition_tags=self._detect_conditions(piece),
                        word_count=len(window),
                    )
                )
                start = end - self.overlap_words if end < len(words) else len(words)
        return chunks

    # -- internal helpers ---------------------------------------------------

    @staticmethod
    def _split_sections(text: str) -> List[tuple[str, str]]:
        """Split text by detected section headers.

        Returns ``(header, body)`` pairs; text before the first header gets an
        empty header. Reference/bibliography sections are dropped. If nothing
        usable remains, the whole text is returned as one untitled section.
        """
        matches = list(_SECTION_RE.finditer(text))
        if not matches:
            return [("", text)]
        sections: List[tuple[str, str]] = []
        # text before first section header
        preamble = text[: matches[0].start()].strip()
        if preamble:
            sections.append(("", preamble))
        for i, match in enumerate(matches):
            header = match.group(1).strip().title()
            body_end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
            body = text[match.end():body_end].strip()
            # Skip reference/bibliography sections
            if header.lower() in _SKIPPED_SECTIONS:
                continue
            if body:
                sections.append((header, body))
        return sections or [("", text)]

    @staticmethod
    def _detect_biomarkers(text: str) -> List[str]:
        """Return the sorted canonical biomarker names mentioned as whole tokens."""
        found = set()
        for match in _BIOMARKER_RE.finditer(text):
            # Collapse internal whitespace so "heart   rate" maps to "Heart Rate".
            canonical = _BIOMARKER_CANONICAL.get(" ".join(match.group(0).lower().split()))
            if canonical is not None:
                found.add(canonical)
        return sorted(found)

    @staticmethod
    def _detect_conditions(text: str) -> List[str]:
        """Return the sorted condition tags whose keywords occur in the text."""
        text_lower = text.lower()
        return sorted(
            {tag for kw, tag in _CONDITION_KEYWORDS.items() if kw in text_lower}
        )
b/src/services/langfuse/__init__.py @@ -0,0 +1,4 @@ +"""MediGuard AI β€” Langfuse observability package.""" +from src.services.langfuse.tracer import LangfuseTracer, make_langfuse_tracer + +__all__ = ["LangfuseTracer", "make_langfuse_tracer"] diff --git a/src/services/langfuse/tracer.py b/src/services/langfuse/tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0b9723a9a9b8b8debbe68d43c62be80e0f9188 --- /dev/null +++ b/src/services/langfuse/tracer.py @@ -0,0 +1,97 @@ +""" +MediGuard AI β€” Langfuse Observability Tracer + +Wraps Langfuse v3 SDK for end-to-end tracing of the RAG pipeline. +Silently no-ops when Langfuse is disabled or unreachable. +""" + +from __future__ import annotations + +import logging +from contextlib import contextmanager +from functools import lru_cache +from typing import Any, Dict, Optional + +from src.settings import get_settings + +logger = logging.getLogger(__name__) + +try: + from langfuse import Langfuse as _Langfuse +except ImportError: + _Langfuse = None # type: ignore[assignment,misc] + + +class LangfuseTracer: + """Thin wrapper around Langfuse for MediGuard pipeline tracing.""" + + def __init__(self, client: Any | None): + self._client = client + self._enabled = client is not None + + @property + def enabled(self) -> bool: + return self._enabled + + def trace(self, name: str, **kwargs: Any): + """Create a new trace (top-level span).""" + if not self._enabled: + return _NullSpan() + return self._client.trace(name=name, **kwargs) + + @contextmanager + def span(self, trace, name: str, **kwargs): + """Context manager for creating a span within a trace.""" + if not self._enabled or trace is None: + yield _NullSpan() + return + s = trace.span(name=name, **kwargs) + try: + yield s + finally: + s.end() + + def score(self, trace_id: str, name: str, value: float, comment: str = ""): + """Attach a score to a trace (for evaluation feedback).""" + if not self._enabled: + return + try: + 
self._client.score(trace_id=trace_id, name=name, value=value, comment=comment) + except Exception as exc: + logger.warning("Langfuse score failed: %s", exc) + + def flush(self): + if self._enabled: + try: + self._client.flush() + except Exception: + pass + + +class _NullSpan: + """Dummy span object that silently swallows calls.""" + + def __getattr__(self, name: str): + return lambda *a, **kw: _NullSpan() + + def end(self): + pass + + +@lru_cache(maxsize=1) +def make_langfuse_tracer() -> LangfuseTracer: + settings = get_settings() + if not settings.langfuse.enabled or _Langfuse is None: + logger.info("Langfuse tracing disabled") + return LangfuseTracer(None) + try: + client = _Langfuse( + public_key=settings.langfuse.public_key, + secret_key=settings.langfuse.secret_key, + host=settings.langfuse.host, + ) + logger.info("Langfuse connected (%s)", settings.langfuse.host) + return LangfuseTracer(client) + except Exception as exc: + logger.warning("Langfuse unavailable: %s", exc) + return LangfuseTracer(None) diff --git a/src/services/ollama/__init__.py b/src/services/ollama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb83880824eec8a57e410fbc70521ff6efcd99bb --- /dev/null +++ b/src/services/ollama/__init__.py @@ -0,0 +1,4 @@ +"""MediGuard AI β€” Ollama client package.""" +from src.services.ollama.client import OllamaClient, make_ollama_client + +__all__ = ["OllamaClient", "make_ollama_client"] diff --git a/src/services/ollama/client.py b/src/services/ollama/client.py new file mode 100644 index 0000000000000000000000000000000000000000..fd99e74cc7dd9ea70d95308528e7a1552caa6f0a --- /dev/null +++ b/src/services/ollama/client.py @@ -0,0 +1,160 @@ +""" +MediGuard AI β€” Ollama Client + +Production-grade wrapper for the Ollama API with health checks, +streaming, and LangChain integration. 
class OllamaClient:
    """Wrapper around the Ollama REST API.

    Owns one ``httpx.Client``; call ``close()`` or use the instance as a
    context manager to release it.

    Args:
        base_url: Root URL of the Ollama server; a trailing slash is removed.
        timeout: Request timeout passed to ``httpx.Client``.
    """

    def __init__(self, base_url: str, *, timeout: int = 120):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._http = httpx.Client(base_url=self.base_url, timeout=timeout)

    def __enter__(self) -> "OllamaClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    # -- Health -------------------------------------------------------------

    def ping(self) -> bool:
        """Return True when ``/api/version`` answers with HTTP 200."""
        try:
            resp = self._http.get("/api/version")
            return resp.status_code == 200
        except Exception:
            return False

    def health(self) -> Dict[str, Any]:
        """Return the JSON body of ``/api/version``.

        Raises:
            OllamaConnectionError: If the request fails for any reason.
        """
        try:
            resp = self._http.get("/api/version")
            resp.raise_for_status()
            return resp.json()
        except Exception as exc:
            raise OllamaConnectionError(f"Cannot reach Ollama: {exc}") from exc

    def list_models(self) -> List[str]:
        """Return model names from ``/api/tags``; an empty list on any failure."""
        try:
            resp = self._http.get("/api/tags")
            resp.raise_for_status()
            return [m["name"] for m in resp.json().get("models", [])]
        except Exception as exc:
            logger.warning("Failed to list Ollama models: %s", exc)
            return []

    # -- Generation ---------------------------------------------------------

    @staticmethod
    def _build_payload(
        model: str,
        prompt: str,
        system: str,
        temperature: float,
        num_ctx: int,
        *,
        stream: bool,
    ) -> Dict[str, Any]:
        """Assemble the ``/api/generate`` request body shared by both modes."""
        payload: Dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "stream": stream,
            "options": {"temperature": temperature, "num_ctx": num_ctx},
        }
        if system:
            payload["system"] = system
        return payload

    def generate(
        self,
        prompt: str,
        *,
        model: Optional[str] = None,
        system: str = "",
        temperature: float = 0.0,
        num_ctx: int = 8192,
    ) -> Dict[str, Any]:
        """Synchronous generation — returns the full response dict.

        Raises:
            OllamaModelNotFoundError: If the server answers with HTTP 404.
            OllamaConnectionError: On any other HTTP or transport failure.
        """
        model = model or get_settings().ollama.model
        payload = self._build_payload(
            model, prompt, system, temperature, num_ctx, stream=False
        )
        try:
            resp = self._http.post("/api/generate", json=payload)
            resp.raise_for_status()
            return resp.json()
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code == 404:
                raise OllamaModelNotFoundError(
                    f"Model '{model}' not found on Ollama server"
                ) from exc
            raise OllamaConnectionError(str(exc)) from exc
        except Exception as exc:
            raise OllamaConnectionError(str(exc)) from exc

    def generate_stream(
        self,
        prompt: str,
        *,
        model: Optional[str] = None,
        system: str = "",
        temperature: float = 0.0,
        num_ctx: int = 8192,
    ) -> Iterator[str]:
        """Streaming generation — yields text tokens.

        Each non-empty response line is parsed as JSON; its ``response``
        field is yielded when non-empty and iteration stops once a line
        reports ``done``.

        Raises:
            OllamaConnectionError: On any failure while requesting or parsing.
        """
        import json

        model = model or get_settings().ollama.model
        payload = self._build_payload(
            model, prompt, system, temperature, num_ctx, stream=True
        )
        try:
            with self._http.stream("POST", "/api/generate", json=payload) as resp:
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        continue
                    data = json.loads(line)
                    token = data.get("response", "")
                    if token:
                        yield token
                    if data.get("done", False):
                        break
        except Exception as exc:
            raise OllamaConnectionError(str(exc)) from exc

    # -- LangChain integration ----------------------------------------------

    def get_langchain_model(
        self,
        *,
        model: Optional[str] = None,
        temperature: float = 0.0,
        json_mode: bool = False,
    ):
        """Return a LangChain ChatOllama instance bound to this server.

        Prefers ``langchain_ollama`` and falls back to the
        ``langchain_community`` implementation when it is not installed.
        """
        model = model or get_settings().ollama.model
        try:
            from langchain_ollama import ChatOllama
        except ImportError:
            from langchain_community.chat_models import ChatOllama

        return ChatOllama(
            model=model,
            temperature=temperature,
            base_url=self.base_url,
            format="json" if json_mode else None,
        )

    def close(self):
        """Close the underlying HTTP client."""
        self._http.close()
settings.ollama.host) + return client diff --git a/src/services/opensearch/__init__.py b/src/services/opensearch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c59b705cca78d8e1c05c61821226cb33d8fe0427 --- /dev/null +++ b/src/services/opensearch/__init__.py @@ -0,0 +1,5 @@ +"""MediGuard AI β€” OpenSearch service package.""" +from src.services.opensearch.client import OpenSearchClient, make_opensearch_client +from src.services.opensearch.index_config import MEDICAL_CHUNKS_MAPPING + +__all__ = ["OpenSearchClient", "make_opensearch_client", "MEDICAL_CHUNKS_MAPPING"] diff --git a/src/services/opensearch/client.py b/src/services/opensearch/client.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8b02bb41f2bab012dbafec8b30f603698df5df --- /dev/null +++ b/src/services/opensearch/client.py @@ -0,0 +1,224 @@ +""" +MediGuard AI β€” OpenSearch Client + +Production search-engine wrapper supporting BM25, vector (KNN), and +hybrid search with Reciprocal Rank Fusion (RRF). 
logger = logging.getLogger(__name__)


class OpenSearchClient:
    """Thin wrapper around *opensearch-py* with medical-domain helpers.

    Args:
        client: A connected ``opensearchpy.OpenSearch`` instance.
        index_name: Index that every operation of this wrapper targets.
    """

    def __init__(self, client: "OpenSearch", index_name: str):
        self._client = client
        self.index_name = index_name

    # -- Health -------------------------------------------------------------

    def health(self) -> Dict[str, Any]:
        """Return the cluster health document."""
        return self._client.cluster.health()

    def ping(self) -> bool:
        """Return the client's ping result, or False if the call raises."""
        try:
            return self._client.ping()
        except Exception:
            return False

    # -- Index management ---------------------------------------------------

    def ensure_index(self, mapping: Dict[str, Any]) -> None:
        """Create the index if it doesn't already exist."""
        if not self._client.indices.exists(index=self.index_name):
            self._client.indices.create(index=self.index_name, body=mapping)
            logger.info("Created OpenSearch index '%s'", self.index_name)
        else:
            logger.info("OpenSearch index '%s' already exists", self.index_name)

    def delete_index(self) -> None:
        """Delete the index when it exists; otherwise do nothing."""
        if self._client.indices.exists(index=self.index_name):
            self._client.indices.delete(index=self.index_name)

    def doc_count(self) -> int:
        """Return the number of documents in the index, or 0 on any failure."""
        try:
            resp = self._client.count(index=self.index_name)
            return resp["count"]
        except Exception:
            return 0

    # -- Indexing -----------------------------------------------------------

    def index_document(self, doc_id: str, body: Dict[str, Any]) -> None:
        """Index (or overwrite) a single document under ``doc_id``."""
        self._client.index(index=self.index_name, id=doc_id, body=body)

    def bulk_index(self, documents: List[Dict[str, Any]]) -> int:
        """Bulk-index a list of dicts, each of which should carry an ``_id`` key.

        The input dicts are left untouched; ``_id`` is moved into the action
        metadata of a copy. Returns how many items the cluster reported as
        indexed (status 200 or 201).
        """
        if not documents:
            return 0
        actions: list[Dict[str, Any]] = []
        for doc in documents:
            actions.append({"index": {"_index": self.index_name, "_id": doc.get("_id")}})
            # Copy without ``_id`` instead of popping it from the caller's dict.
            actions.append({key: value for key, value in doc.items() if key != "_id"})
        resp = self._client.bulk(body=actions, refresh="wait_for")
        indexed = sum(
            1
            for item in resp.get("items", [])
            if item.get("index", {}).get("status") in (200, 201)
        )
        logger.info("Bulk-indexed %d / %d documents", indexed, len(documents))
        if indexed < len(documents):
            logger.warning(
                "Bulk indexing rejected %d document(s)", len(documents) - indexed
            )
        return indexed

    # -- BM25 search --------------------------------------------------------

    def search_bm25(
        self,
        query: str,
        *,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Run a boosted ``multi_match`` query, optionally narrowed by filters."""
        body: Dict[str, Any] = {
            "size": top_k,
            "query": {
                "bool": {
                    "must": [
                        {
                            "multi_match": {
                                "query": query,
                                "fields": [
                                    "chunk_text^3",
                                    "title^2",
                                    "section_title^1.5",
                                    "abstract^1",
                                ],
                                "type": "best_fields",
                            }
                        }
                    ]
                }
            },
        }
        if filters:
            body["query"]["bool"]["filter"] = self._build_filters(filters)
        return self._execute_search(body)

    # -- Vector (KNN) search ------------------------------------------------

    def search_vector(
        self,
        embedding: List[float],
        *,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Run a k-NN query on the ``embedding`` field.

        When ``filters`` is given, the k-NN clause is wrapped in a ``bool``
        query with the same filter clauses BM25 uses. The filter is applied
        to the k nearest candidates, so fewer than ``top_k`` hits may return.
        """
        knn_clause: Dict[str, Any] = {
            "knn": {
                "embedding": {
                    "vector": embedding,
                    "k": top_k,
                }
            }
        }
        if filters:
            query: Dict[str, Any] = {
                "bool": {
                    "must": [knn_clause],
                    "filter": self._build_filters(filters),
                }
            }
        else:
            query = knn_clause
        return self._execute_search({"size": top_k, "query": query})

    # -- Hybrid search (RRF) ------------------------------------------------

    def search_hybrid(
        self,
        query: str,
        embedding: List[float],
        *,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None,
        bm25_weight: float = 0.4,
        vector_weight: float = 0.6,
    ) -> List[Dict[str, Any]]:
        """Reciprocal Rank Fusion of BM25 + KNN results.

        ``bm25_weight`` and ``vector_weight`` scale each list's reciprocal-rank
        contribution before the fused ranking is computed.
        """
        bm25_results = self.search_bm25(query, top_k=top_k, filters=filters)
        vector_results = self.search_vector(embedding, top_k=top_k, filters=filters)
        return self._rrf_fuse(
            bm25_results,
            vector_results,
            top_k=top_k,
            weight_a=bm25_weight,
            weight_b=vector_weight,
        )

    # -- Internal helpers ---------------------------------------------------

    def _execute_search(self, body: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Execute a search body and flatten hits to ``_id``/``_score``/source.

        Raises:
            SearchQueryError: If the client raises while searching.
        """
        try:
            resp = self._client.search(index=self.index_name, body=body)
        except Exception as exc:
            raise SearchQueryError(str(exc)) from exc
        hits = resp.get("hits", {}).get("hits", [])
        return [
            {
                "_id": h["_id"],
                "_score": h.get("_score", 0.0),
                **h.get("_source", {}),
            }
            for h in hits
        ]

    @staticmethod
    def _build_filters(filters: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Translate ``{field: value}`` into ``term`` / ``terms`` clauses."""
        return [
            {"terms": {key: value}} if isinstance(value, list) else {"term": {key: value}}
            for key, value in filters.items()
        ]

    @staticmethod
    def _rrf_fuse(
        results_a: List[Dict[str, Any]],
        results_b: List[Dict[str, Any]],
        *,
        k: int = 60,
        top_k: int = 10,
        weight_a: float = 1.0,
        weight_b: float = 1.0,
    ) -> List[Dict[str, Any]]:
        """Weighted Reciprocal Rank Fusion of two ranked result lists.

        Each document scores ``weight / (k + rank)`` per list it appears in
        (ranks start at 1). With the default weights of 1.0 this is plain RRF.
        The returned documents carry the fused value in ``_score``.
        """
        scores: Dict[str, float] = {}
        docs: Dict[str, Dict[str, Any]] = {}
        for results, weight in ((results_a, weight_a), (results_b, weight_b)):
            for rank, doc in enumerate(results, 1):
                doc_id = doc["_id"]
                scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank)
                docs[doc_id] = doc
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
        return [
            {**docs[doc_id], "_score": score}
            for doc_id, score in ranked
        ]
# Index definition (settings + mappings) for medical document chunks.
# "settings" enables k-NN and defines a custom analyzer with a medical synonym
# filter; "mappings" declares analysed text fields, keyword/integer fields for
# filtering, one knn_vector field, and an indexing timestamp.
MEDICAL_CHUNKS_MAPPING: dict = {
    "settings": {
        "index": {
            "knn": True,
            "knn.algo_param.ef_search": 256,
        },
        # Single shard, no replicas.
        # NOTE(review): looks like a single-node/dev sizing; confirm for production.
        "number_of_shards": 1,
        "number_of_replicas": 0,
        "analysis": {
            "filter": {
                # Each entry is one comma-separated group of equivalent terms.
                "medical_synonyms": {
                    "type": "synonym",
                    "synonyms": [
                        "diabetes mellitus, DM, diabetes",
                        "HbA1c, glycated hemoglobin, hemoglobin A1c, A1c",
                        "glucose, blood sugar, blood glucose",
                        "LDL, low density lipoprotein, bad cholesterol",
                        "HDL, high density lipoprotein, good cholesterol",
                        "WBC, white blood cells, leukocytes",
                        "RBC, red blood cells, erythrocytes",
                        "MCV, mean corpuscular volume",
                        "BP, blood pressure",
                        "CRP, C-reactive protein",
                        "ALT, alanine aminotransferase, SGPT",
                        "AST, aspartate aminotransferase, SGOT",
                        "TSH, thyroid stimulating hormone",
                        "BMI, body mass index",
                        "anemia, anaemia",
                        "thrombocytopenia, low platelets",
                        "thalassemia, thalassaemia",
                    ],
                }
            },
            "analyzer": {
                # Filters run in the listed order: lowercase, synonyms, stop
                # words, then snowball stemming.
                "medical_analyzer": {
                    "type": "custom",
                    "tokenizer": "standard",
                    "filter": [
                        "lowercase",
                        "medical_synonyms",
                        "stop",
                        "snowball",
                    ],
                }
            },
        },
    },
    "mappings": {
        "properties": {
            # -- Text fields (full-text searchable) --
            "chunk_text": {
                "type": "text",
                "analyzer": "medical_analyzer",
            },
            "title": {"type": "text", "analyzer": "medical_analyzer"},
            # section_title declares no analyzer, unlike the other text fields.
            "section_title": {"type": "text"},
            "abstract": {"type": "text", "analyzer": "medical_analyzer"},
            # -- Keyword / filterable (exact-match) --
            "document_id": {"type": "keyword"},
            "document_type": {"type": "keyword"},  # guideline, research, reference
            "condition_tags": {"type": "keyword"},  # diabetes, anemia, ...
            "biomarkers_mentioned": {"type": "keyword"},  # Glucose, HbA1c, ...
            "source_file": {"type": "keyword"},
            "page_number": {"type": "integer"},
            "chunk_index": {"type": "integer"},
            "publication_year": {"type": "integer"},
            # -- Vector (KNN) --
            # NOTE(review): the dimension is fixed at 1024 here; confirm every
            # embedding provider that can write to this index returns
            # 1024-dimensional vectors.
            # NOTE(review): confirm the target OpenSearch version still
            # supports the "nmslib" engine for new indexes.
            "embedding": {
                "type": "knn_vector",
                "dimension": 1024,
                "method": {
                    "name": "hnsw",
                    "space_type": "cosinesimil",
                    "engine": "nmslib",
                    "parameters": {"ef_construction": 256, "m": 48},
                },
            },
            # -- Timestamps --
            "indexed_at": {"type": "date"},
        }
    },
}
+""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass, field +from functools import lru_cache +from pathlib import Path +from typing import List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class ParsedSection: + """One logical section extracted from a PDF.""" + + title: str + text: str + page_numbers: List[int] = field(default_factory=list) + + +@dataclass +class ParsedDocument: + """Result of parsing a single PDF.""" + + filename: str + content_hash: str + full_text: str + sections: List[ParsedSection] = field(default_factory=list) + page_count: int = 0 + error: Optional[str] = None + + +class PDFParserService: + """Unified PDF parsing with Docling β†’ PyPDF fallback.""" + + def __init__(self) -> None: + self._has_docling = self._check_docling() + + @staticmethod + def _check_docling() -> bool: + try: + import docling # noqa: F401 + return True + except ImportError: + logger.info("Docling not installed β€” using PyPDF fallback") + return False + + def parse(self, path: Path) -> ParsedDocument: + """Parse a PDF file and return structured text.""" + if not path.exists(): + return ParsedDocument( + filename=path.name, + content_hash="", + full_text="", + error=f"File not found: {path}", + ) + + content_hash = hashlib.sha256(path.read_bytes()).hexdigest() + + if self._has_docling: + return self._parse_with_docling(path, content_hash) + return self._parse_with_pypdf(path, content_hash) + + # ------------------------------------------------------------------ # + # Docling (preferred) + # ------------------------------------------------------------------ # + + def _parse_with_docling(self, path: Path, content_hash: str) -> ParsedDocument: + try: + from docling.document_converter import DocumentConverter + + converter = DocumentConverter() + result = converter.convert(str(path)) + doc = result.document + + sections: list[ParsedSection] = [] + full_parts: list[str] = [] + + for element in 
doc.iterate_items(): + text = element.text if hasattr(element, "text") else str(element) + if text.strip(): + full_parts.append(text.strip()) + sections.append( + ParsedSection( + title=getattr(element, "label", ""), + text=text.strip(), + ) + ) + + full_text = "\n\n".join(full_parts) + return ParsedDocument( + filename=path.name, + content_hash=content_hash, + full_text=full_text, + sections=sections, + page_count=getattr(doc, "num_pages", 0), + ) + except Exception as exc: + logger.warning("Docling failed for %s β€” falling back to PyPDF: %s", path.name, exc) + return self._parse_with_pypdf(path, content_hash) + + # ------------------------------------------------------------------ # + # PyPDF fallback + # ------------------------------------------------------------------ # + + def _parse_with_pypdf(self, path: Path, content_hash: str) -> ParsedDocument: + try: + from pypdf import PdfReader + + reader = PdfReader(str(path)) + pages_text: list[str] = [] + for i, page in enumerate(reader.pages): + text = page.extract_text() or "" + if text.strip(): + pages_text.append(text.strip()) + + full_text = "\n\n".join(pages_text) + sections = [ + ParsedSection(title=f"Page {i + 1}", text=t, page_numbers=[i + 1]) + for i, t in enumerate(pages_text) + ] + + return ParsedDocument( + filename=path.name, + content_hash=content_hash, + full_text=full_text, + sections=sections, + page_count=len(reader.pages), + ) + except Exception as exc: + logger.error("PyPDF failed for %s: %s", path.name, exc) + return ParsedDocument( + filename=path.name, + content_hash=content_hash, + full_text="", + error=str(exc), + ) + + # ------------------------------------------------------------------ # + # Batch + # ------------------------------------------------------------------ # + + def parse_directory(self, directory: Path) -> List[ParsedDocument]: + """Parse all PDFs in a directory.""" + results: list[ParsedDocument] = [] + for pdf_path in sorted(directory.glob("*.pdf")): + logger.info("Parsing 
def _get_telegram():
    """Import python-telegram-bot lazily and return the symbols the bot uses.

    Returns:
        The tuple ``(Update, Application, CommandHandler, MessageHandler, filters)``.

    Raises:
        ImportError: If ``python-telegram-bot`` is not installed; the original
            import failure is chained as the cause.
    """
    global _Application
    try:
        from telegram import Update
        from telegram.ext import Application, CommandHandler, MessageHandler, filters
    except ImportError as exc:
        raise ImportError(
            "python-telegram-bot is required for the Telegram bot. "
            "Install it with: pip install 'mediguard[telegram]' or pip install python-telegram-bot"
        ) from exc

    # Record the class in the module-level lazy-import slot.
    _Application = Application
    return Update, Application, CommandHandler, MessageHandler, filters


class MediGuardTelegramBot:
    """Telegram bot that forwards user text to the API ``/ask`` endpoint via ``httpx``."""

    # Telegram max message = 4096 chars; cut earlier to leave room for the suffix.
    _MAX_ANSWER_CHARS = 4000
    _TRUNCATION_SUFFIX = "\n\n… (truncated)"

    def __init__(
        self,
        token: Optional[str] = None,
        api_base_url: str = "http://localhost:8000",
    ) -> None:
        """Store the bot token and the API base URL.

        Args:
            token: Bot token. When omitted, ``TELEGRAM_BOT_TOKEN`` is read
                first (original behaviour), then ``TELEGRAM__BOT_TOKEN`` — the
                name used by ``.env.example`` and the ``TELEGRAM__`` settings
                prefix.
            api_base_url: Base URL of the MediGuard API; a trailing ``/`` is
                removed.

        Raises:
            ValueError: If no token is supplied and neither variable is set.
        """
        self._token = (
            token
            or os.getenv("TELEGRAM_BOT_TOKEN")
            or os.getenv("TELEGRAM__BOT_TOKEN", "")
        )
        self._api_base = api_base_url.rstrip("/")

        if not self._token:
            raise ValueError("TELEGRAM_BOT_TOKEN is required")

    @classmethod
    def _truncate(cls, answer: str) -> str:
        """Cut *answer* to the bot's length budget, marking it when shortened."""
        if len(answer) > cls._MAX_ANSWER_CHARS:
            return answer[: cls._MAX_ANSWER_CHARS] + cls._TRUNCATION_SUFFIX
        return answer

    def run(self) -> None:
        """Start the bot (blocking).

        Registers the ``/start`` and ``/help`` commands plus a handler for
        plain text, then polls Telegram until the process is stopped.
        """
        import httpx

        Update, Application, CommandHandler, MessageHandler, filters = _get_telegram()

        app = Application.builder().token(self._token).build()

        async def start_handler(update: Update, context) -> None:
            # ``update.message`` is None for edited messages / channel posts;
            # ``effective_message`` covers those update kinds as well.
            message = update.effective_message
            if message is None:
                return
            await message.reply_text(
                "Welcome to MediGuard AI! Send me a medical question or biomarker values "
                "and I'll provide evidence-based insights.\n\n"
                "Disclaimer: This is not a substitute for professional medical advice."
            )

        async def help_handler(update: Update, context) -> None:
            message = update.effective_message
            if message is None:
                return
            await message.reply_text(
                "Send me:\n"
                "• A medical question (e.g. 'What does high HbA1c mean?')\n"
                "• Biomarker values (e.g. 'My glucose is 180 and HbA1c 8.2')\n\n"
                "I'll provide evidence-based analysis."
            )

        async def message_handler(update: Update, context) -> None:
            message = update.effective_message
            if message is None:
                return

            user_text = message.text or ""
            if not user_text.strip():
                return

            await message.reply_text("Analyzing… please wait.")

            try:
                async with httpx.AsyncClient(timeout=60.0) as client:
                    resp = await client.post(
                        f"{self._api_base}/ask",
                        json={"question": user_text},
                    )
                    resp.raise_for_status()
                    data = resp.json()
                    # ``or`` also covers an explicit null / empty answer,
                    # which would otherwise break len() or be rejected by
                    # Telegram as an empty message.
                    answer = data.get("answer") or "Sorry, I could not generate an answer."
            except Exception as exc:
                # Best-effort: the user always gets a reply, failures are logged.
                logger.error("Telegram→API call failed: %s", exc)
                answer = "Sorry, I'm having trouble processing your request right now."

            await message.reply_text(self._truncate(str(answer)))

        app.add_handler(CommandHandler("start", start_handler))
        app.add_handler(CommandHandler("help", help_handler))
        app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, message_handler))

        logger.info("Telegram bot starting (polling mode)")
        app.run_polling()
"http://localhost:9200" + index_name: str = "medical_chunks" + username: str = "" + password: str = "" + verify_certs: bool = False + timeout: int = 30 + + model_config = {"env_prefix": "OPENSEARCH__"} + + +class RedisSettings(_Base): + host: str = "localhost" + port: int = 6379 + db: int = 0 + ttl_seconds: int = 21600 # 6 hours default + enabled: bool = True + + model_config = {"env_prefix": "REDIS__"} + + +class OllamaSettings(_Base): + host: str = "http://localhost:11434" + model: str = "llama3.1:8b" + embedding_model: str = "nomic-embed-text" + timeout: int = 120 + num_ctx: int = 8192 + + model_config = {"env_prefix": "OLLAMA__"} + + +class LLMSettings(_Base): + provider: Literal["groq", "gemini", "ollama"] = "groq" + temperature: float = 0.0 + groq_api_key: str = "" + groq_model: str = "llama-3.3-70b-versatile" + google_api_key: str = "" + gemini_model: str = "gemini-2.0-flash" + + model_config = {"env_prefix": "LLM__"} + + +class EmbeddingSettings(_Base): + provider: Literal["jina", "google", "huggingface", "ollama"] = "google" + jina_api_key: str = "" + jina_model: str = "jina-embeddings-v3" + dimension: int = 1024 + google_api_key: str = "" + huggingface_model: str = "sentence-transformers/all-MiniLM-L6-v2" + batch_size: int = 64 + + model_config = {"env_prefix": "EMBEDDING__"} + + +class ChunkingSettings(_Base): + chunk_size: int = 600 # words + chunk_overlap: int = 100 # words + min_chunk_size: int = 50 + section_aware: bool = True + + model_config = {"env_prefix": "CHUNKING__"} + + +class LangfuseSettings(_Base): + enabled: bool = False + public_key: str = "" + secret_key: str = "" + host: str = "http://localhost:3001" + + model_config = {"env_prefix": "LANGFUSE__"} + + +class TelegramSettings(_Base): + enabled: bool = False + bot_token: str = "" + allowed_users: str = "" # comma-separated user IDs + + model_config = {"env_prefix": "TELEGRAM__"} + + +class BiomarkerSettings(_Base): + reference_file: str = "config/biomarker_references.json" + 
class Settings(_Base):
    """Root configuration — aggregates all sub-settings.

    Holds application-level metadata plus one attribute per sub-settings
    group. Every group defaults to a freshly constructed instance of its
    class (``default_factory``), so ``Settings()`` can be built with no
    arguments.
    """

    # Application metadata.
    app_name: str = "MediGuard AI"
    app_version: str = "2.0.0"
    # Restricted to the three literals below; any other value fails validation.
    environment: Literal["development", "staging", "production"] = "development"
    debug: bool = False

    # Sub-settings (populated from env with nesting)
    api: APISettings = Field(default_factory=APISettings)
    postgres: PostgresSettings = Field(default_factory=PostgresSettings)
    opensearch: OpenSearchSettings = Field(default_factory=OpenSearchSettings)
    redis: RedisSettings = Field(default_factory=RedisSettings)
    ollama: OllamaSettings = Field(default_factory=OllamaSettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    chunking: ChunkingSettings = Field(default_factory=ChunkingSettings)
    langfuse: LangfuseSettings = Field(default_factory=LangfuseSettings)
    telegram: TelegramSettings = Field(default_factory=TelegramSettings)
    biomarker: BiomarkerSettings = Field(default_factory=BiomarkerSettings)
    pdf: MedicalPDFSettings = Field(default_factory=MedicalPDFSettings)

    # Restates the same three options that ``_Base`` declares:
    # "__" as the nested-env delimiter, immutable instances, unknown keys ignored.
    model_config = {
        "env_nested_delimiter": "__",
        "frozen": True,
        "extra": "ignore",
    }
0000000000000000000000000000000000000000..30413e293c937daa8bf62be3f58d968faa584de7 --- /dev/null +++ b/tests/test_agentic_rag.py @@ -0,0 +1,202 @@ +""" +Tests for src/services/agents/ β€” agentic RAG pipeline. +""" + +import json +from dataclasses import dataclass +from typing import Any, Optional +from unittest.mock import MagicMock + +import pytest + + +# ----------------------------------------------------------------------- +# Mock context and LLM +# ----------------------------------------------------------------------- + + +class MockMessage: + def __init__(self, content: str): + self.content = content + + +class MockLLM: + """Programmable mock LLM that returns canned responses.""" + + def __init__(self, responses: list[str] | None = None): + self._responses = responses or [] + self._call_count = 0 + + def invoke(self, messages: list) -> MockMessage: + if self._call_count < len(self._responses): + resp = self._responses[self._call_count] + else: + resp = '{"score": 80}' + self._call_count += 1 + return MockMessage(resp) + + +@dataclass +class MockContext: + llm: Any = None + embedding_service: Any = None + opensearch_client: Any = None + cache: Any = None + tracer: Any = None + + +# ----------------------------------------------------------------------- +# Guardrail node +# ----------------------------------------------------------------------- + +class TestGuardrailNode: + def test_in_scope_query(self): + from src.services.agents.nodes.guardrail_node import guardrail_node + + ctx = MockContext(llm=MockLLM(['{"score": 85}'])) + state = {"query": "What does high HbA1c mean?"} + result = guardrail_node(state, context=ctx) + assert result["is_in_scope"] is True + assert result["guardrail_score"] == 85.0 + + def test_out_of_scope_query(self): + from src.services.agents.nodes.guardrail_node import guardrail_node + + ctx = MockContext(llm=MockLLM(['{"score": 10}'])) + state = {"query": "What is the weather today?"} + result = guardrail_node(state, context=ctx) + 
class TestOutOfScopeNode:
    """The out-of-scope node must answer with the canned rejection text."""

    def test_returns_rejection(self):
        from src.services.agents.nodes.out_of_scope_node import out_of_scope_node
        from src.services.agents.prompts import OUT_OF_SCOPE_RESPONSE

        # The node is called with an empty state and a context whose fields are all None.
        empty_state: dict = {}
        outcome = out_of_scope_node(empty_state, context=MockContext())
        assert outcome["final_answer"] == OUT_OF_SCOPE_RESPONSE
len(result["relevant_documents"]) == 1 + assert result["grading_results"][0]["relevant"] is True + assert result["grading_results"][1]["relevant"] is False + + def test_empty_docs_needs_rewrite(self): + from src.services.agents.nodes.grade_documents_node import grade_documents_node + + ctx = MockContext() + state = {"query": "test", "retrieved_documents": []} + result = grade_documents_node(state, context=ctx) + assert result["needs_rewrite"] is True + + +# ----------------------------------------------------------------------- +# Rewrite query node +# ----------------------------------------------------------------------- + +class TestRewriteQueryNode: + def test_rewrites(self): + from src.services.agents.nodes.rewrite_query_node import rewrite_query_node + + ctx = MockContext(llm=MockLLM(["diabetes HbA1c glucose management guidelines"])) + state = {"query": "sugar problems"} + result = rewrite_query_node(state, context=ctx) + assert "diabetes" in result["rewritten_query"].lower() or result["rewritten_query"] + + def test_llm_failure_keeps_original(self): + from src.services.agents.nodes.rewrite_query_node import rewrite_query_node + + broken_llm = MagicMock() + broken_llm.invoke.side_effect = Exception("timeout") + ctx = MockContext(llm=broken_llm) + state = {"query": "original query"} + result = rewrite_query_node(state, context=ctx) + assert result["rewritten_query"] == "original query" + + +# ----------------------------------------------------------------------- +# Generate answer node +# ----------------------------------------------------------------------- + +class TestGenerateAnswerNode: + def test_generates_answer(self): + from src.services.agents.nodes.generate_answer_node import generate_answer_node + + ctx = MockContext(llm=MockLLM(["Based on the evidence, HbA1c of 8.2% indicates poor glycemic control."])) + state = { + "query": "What does HbA1c 8.2 mean?", + "relevant_documents": [ + {"title": "Diabetes Guide", "section": "Diagnosis", "text": "HbA1c 
above 6.5% indicates diabetes."} + ], + } + result = generate_answer_node(state, context=ctx) + assert "final_answer" in result + assert len(result["final_answer"]) > 10 + + def test_llm_failure_returns_fallback(self): + from src.services.agents.nodes.generate_answer_node import generate_answer_node + + broken_llm = MagicMock() + broken_llm.invoke.side_effect = Exception("dead") + ctx = MockContext(llm=broken_llm) + state = {"query": "test", "relevant_documents": []} + result = generate_answer_node(state, context=ctx) + assert "apologize" in result["final_answer"].lower() + assert len(result["errors"]) > 0 + + +# ----------------------------------------------------------------------- +# Agentic RAG state +# ----------------------------------------------------------------------- + +class TestAgenticRAGState: + def test_state_is_typed_dict(self): + from src.services.agents.state import AgenticRAGState + # Should be usable as a dict type hint + state: AgenticRAGState = { + "query": "test", + "errors": [], + } + assert state["query"] == "test" diff --git a/tests/test_biomarker_service.py b/tests/test_biomarker_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e62d4c0a9564328a7b90a3f76854dea116ab18cf --- /dev/null +++ b/tests/test_biomarker_service.py @@ -0,0 +1,47 @@ +""" +Tests for src/services/biomarker/service.py β€” production biomarker validation. 
+""" + +import pytest + +from src.services.biomarker.service import BiomarkerService, ValidationReport + + +@pytest.fixture +def service(): + return BiomarkerService() + + +def test_validate_known_biomarkers(service: BiomarkerService): + """Should validate known biomarkers correctly.""" + report = service.validate({"Glucose": 185.0, "HbA1c": 8.2}) + assert isinstance(report, ValidationReport) + assert report.recognized_count >= 1 + # At least one result should exist + assert len(report.results) >= 1 + + +def test_validate_critical_generates_alert(service: BiomarkerService): + """Critically abnormal values should generate safety alerts.""" + # Glucose < 40 or > 500 should be critical + report = service.validate({"Glucose": 550.0}) + if report.recognized_count > 0: + critical = [r for r in report.results if r.status.startswith("CRITICAL")] + # If the validator flags it as critical, there should be alerts + if critical: + assert len(report.safety_alerts) > 0 + + +def test_validate_unrecognized(service: BiomarkerService): + """Unknown biomarker names should be listed as unrecognized.""" + report = service.validate({"FakeMarkerXYZ": 42.0}) + assert "FakeMarkerXYZ" in report.unrecognized + assert report.recognized_count == 0 + + +def test_list_supported(service: BiomarkerService): + """Should return a list of supported biomarkers.""" + supported = service.list_supported() + assert isinstance(supported, list) + # We know the validator has 24 biomarkers + assert len(supported) >= 20 diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..80078aee75f99efcd5042d647ff8fc2ce2a38559 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,27 @@ +""" +Tests for src/services/cache/redis_cache.py β€” graceful degradation. 
+""" + +import pytest + +from src.services.cache.redis_cache import RedisCache + + +class TestNullCache: + """When Redis is disabled, the NullCache should degrade gracefully.""" + + def test_null_cache_get_returns_none(self): + from src.services.cache.redis_cache import _NullCache + cache = _NullCache() + assert cache.get("anything") is None + + def test_null_cache_set_noop(self): + from src.services.cache.redis_cache import _NullCache + cache = _NullCache() + # Should not raise + cache.set("key", "value", ttl=10) + + def test_null_cache_delete_noop(self): + from src.services.cache.redis_cache import _NullCache + cache = _NullCache() + cache.delete("key") diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..3d09facf0c9c40ff71198557debe84703c214d01 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,43 @@ +""" +Tests for src/exceptions.py β€” domain exception hierarchy. +""" + +import pytest + +from src.exceptions import ( + AnalysisError, + BiomarkerError, + CacheError, + DatabaseError, + EmbeddingError, + GuardrailError, + LLMError, + MediGuardError, + ObservabilityError, + OllamaConnectionError, + OutOfScopeError, + PDFParsingError, + SearchError, + TelegramError, +) + + +def test_all_exceptions_inherit_from_root(): + """Every domain exception should inherit from MediGuardError.""" + for exc_cls in [ + DatabaseError, SearchError, EmbeddingError, PDFParsingError, + LLMError, OllamaConnectionError, BiomarkerError, AnalysisError, + GuardrailError, OutOfScopeError, CacheError, ObservabilityError, + TelegramError, + ]: + assert issubclass(exc_cls, MediGuardError), f"{exc_cls.__name__} must inherit MediGuardError" + + +def test_ollama_inherits_llm(): + assert issubclass(OllamaConnectionError, LLMError) + + +def test_exception_message(): + exc = SearchError("OpenSearch timeout") + assert str(exc) == "OpenSearch timeout" + assert isinstance(exc, MediGuardError) diff --git 
a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..baf8089a5736322c4d3552663755247a4594f012 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,29 @@ +""" +Tests for src/models/analysis.py β€” SQLAlchemy ORM models. +""" + +from src.models.analysis import MedicalDocument, PatientAnalysis, SOPVersion + + +def test_patient_analysis_tablename(): + assert PatientAnalysis.__tablename__ == "patient_analyses" + + +def test_medical_document_tablename(): + assert MedicalDocument.__tablename__ == "medical_documents" + + +def test_sop_version_tablename(): + assert SOPVersion.__tablename__ == "sop_versions" + + +def test_patient_analysis_has_columns(): + cols = {c.name for c in PatientAnalysis.__table__.columns} + expected = {"id", "request_id", "biomarkers", "predicted_disease", "created_at"} + assert expected.issubset(cols) + + +def test_medical_document_has_columns(): + cols = {c.name for c in MedicalDocument.__table__.columns} + expected = {"id", "title", "content_hash", "parse_status", "created_at"} + assert expected.issubset(cols) diff --git a/tests/test_opensearch_config.py b/tests/test_opensearch_config.py new file mode 100644 index 0000000000000000000000000000000000000000..38f2eebda85400b6d4e860265dfe284a27ac75e8 --- /dev/null +++ b/tests/test_opensearch_config.py @@ -0,0 +1,32 @@ +""" +Tests for src/services/opensearch/index_config.py β€” OpenSearch mapping. 
+""" + +from src.services.opensearch.index_config import MEDICAL_CHUNKS_MAPPING + + +def test_mapping_has_required_fields(): + """The mapping should define all required fields.""" + props = MEDICAL_CHUNKS_MAPPING["mappings"]["properties"] + required = ["chunk_text", "title", "embedding", "biomarkers_mentioned", "condition_tags"] + for field_name in required: + assert field_name in props, f"Missing field: {field_name}" + + +def test_knn_vector_config(): + """The embedding field should be configured for KNN.""" + embed = MEDICAL_CHUNKS_MAPPING["mappings"]["properties"]["embedding"] + assert embed["type"] == "knn_vector" + assert embed["dimension"] == 1024 + + +def test_synonym_analyzer(): + """Mapping should include a medical synonym analyzer.""" + analyzers = MEDICAL_CHUNKS_MAPPING["settings"]["analysis"]["analyzer"] + assert "medical_analyzer" in analyzers + + +def test_knn_enabled(): + """KNN should be enabled in settings.""" + settings = MEDICAL_CHUNKS_MAPPING["settings"] + assert settings["index"]["knn"] is True diff --git a/tests/test_pdf_parser.py b/tests/test_pdf_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..872b055a46d23be3394b8b981803d9f774172ff4 --- /dev/null +++ b/tests/test_pdf_parser.py @@ -0,0 +1,42 @@ +""" +Tests for src/services/pdf_parser/service.py β€” PDF parsing. 
+""" + +from pathlib import Path + +import pytest + +from src.services.pdf_parser.service import PDFParserService, ParsedDocument + + +@pytest.fixture +def parser(): + return PDFParserService() + + +def test_missing_file(parser: PDFParserService): + """Should return error for missing files.""" + result = parser.parse(Path("/nonexistent/fake.pdf")) + assert isinstance(result, ParsedDocument) + assert result.error is not None + assert "not found" in result.error.lower() + + +def test_parse_directory_empty(parser: PDFParserService, tmp_path: Path): + """Empty directory should return empty list.""" + results = parser.parse_directory(tmp_path) + assert results == [] + + +def test_parse_directory_with_pdf(parser: PDFParserService, tmp_path: Path): + """Should parse PDFs found in a directory.""" + # Check if there are any real PDFs in data/medical_pdfs + pdf_dir = Path("data/medical_pdfs") + if pdf_dir.exists() and list(pdf_dir.glob("*.pdf")): + results = parser.parse_directory(pdf_dir) + assert len(results) > 0 + for doc in results: + assert isinstance(doc, ParsedDocument) + assert doc.filename.endswith(".pdf") + else: + pytest.skip("No medical PDFs available for testing") diff --git a/tests/test_production_api.py b/tests/test_production_api.py new file mode 100644 index 0000000000000000000000000000000000000000..5dd8a70b892f2031e35e4abadd22f2d1526874f6 --- /dev/null +++ b/tests/test_production_api.py @@ -0,0 +1,85 @@ +""" +Tests for the production FastAPI app (src/main.py) β€” endpoint smoke tests. + +These tests use FastAPI's TestClient with mocked backing services +so they run without Docker infrastructure. 
+""" + +import pytest +from unittest.mock import MagicMock, patch + +from fastapi.testclient import TestClient + + +@pytest.fixture +def client(): + """Create a test client with mocked services.""" + # We need to prevent the lifespan from actually connecting to services + with patch("src.main.lifespan") as mock_lifespan: + # Use a no-op lifespan + from contextlib import asynccontextmanager + + @asynccontextmanager + async def _noop_lifespan(app): + import time + app.state.start_time = time.time() + app.state.version = "2.0.0-test" + app.state.opensearch_client = None + app.state.embedding_service = None + app.state.cache = None + app.state.ollama_client = None + app.state.tracer = None + app.state.rag_service = None + app.state.ragbot_service = None + yield + + mock_lifespan.side_effect = _noop_lifespan + + from src.main import create_app + app = create_app() + app.router.lifespan_context = _noop_lifespan + with TestClient(app) as tc: + yield tc + + +def test_root(client: TestClient): + resp = client.get("/") + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "MediGuard AI" + assert "endpoints" in data + + +def test_health(client: TestClient): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "healthy" + assert "version" in data + + +def test_ask_no_service(client: TestClient): + """Without RAG service, /ask should return 503.""" + resp = client.post("/ask", json={"question": "What is diabetes?"}) + assert resp.status_code == 503 + + +def test_search_no_service(client: TestClient): + """Without OpenSearch, /search should return 503.""" + resp = client.post("/search", json={"query": "diabetes", "top_k": 5}) + assert resp.status_code == 503 + + +def test_analyze_no_service(client: TestClient): + """Without RagBot, /analyze/structured should return 503.""" + resp = client.post( + "/analyze/structured", + json={"biomarkers": {"Glucose": 185.0}}, + ) + assert resp.status_code == 
503 + + +def test_validation_error(client: TestClient): + """Invalid request should return 422.""" + resp = client.post("/ask", json={"question": "ab"}) # too short + assert resp.status_code == 422 diff --git a/tests/test_prompts.py b/tests/test_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..a146b9c6d95deb78eaee0ec21fcdf7775d581b9f --- /dev/null +++ b/tests/test_prompts.py @@ -0,0 +1,38 @@ +""" +Tests for src/services/agents/prompts.py β€” prompt templates. +""" + +from src.services.agents.prompts import ( + GRADING_SYSTEM, + GUARDRAIL_SYSTEM, + OUT_OF_SCOPE_RESPONSE, + RAG_GENERATION_SYSTEM, + REWRITE_SYSTEM, +) + + +def test_guardrail_prompt_has_score(): + """Guardrail prompt should ask for a 0-100 score.""" + assert "score" in GUARDRAIL_SYSTEM.lower() + assert "0" in GUARDRAIL_SYSTEM + assert "100" in GUARDRAIL_SYSTEM + + +def test_grading_prompt_has_relevant(): + """Grading prompt should ask for relevant true/false.""" + assert "relevant" in GRADING_SYSTEM.lower() + + +def test_rag_generation_has_citation(): + """RAG generation prompt should mention citations.""" + assert "citation" in RAG_GENERATION_SYSTEM.lower() or "cite" in RAG_GENERATION_SYSTEM.lower() + + +def test_out_of_scope_is_polite(): + """Out-of-scope response should be informative and polite.""" + assert "medical" in OUT_OF_SCOPE_RESPONSE.lower() + assert len(OUT_OF_SCOPE_RESPONSE) > 50 + + +def test_rewrite_prompt_exists(): + assert len(REWRITE_SYSTEM) > 50 diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..fee3504b3ca6b31115a0ae7c00b1389905e3dc6a --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,87 @@ +""" +Tests for src/schemas/schemas.py β€” Pydantic request/response models. 
+""" + +import pytest +from pydantic import ValidationError + +from src.schemas.schemas import ( + AskRequest, + AskResponse, + HealthResponse, + NaturalAnalysisRequest, + SearchRequest, + SearchResponse, + StructuredAnalysisRequest, +) + + +class TestNaturalAnalysisRequest: + def test_valid(self): + req = NaturalAnalysisRequest(message="My glucose is 185 and HbA1c 8.2") + assert req.message == "My glucose is 185 and HbA1c 8.2" + + def test_too_short(self): + with pytest.raises(ValidationError): + NaturalAnalysisRequest(message="hi") + + +class TestStructuredAnalysisRequest: + def test_valid(self): + req = StructuredAnalysisRequest(biomarkers={"Glucose": 185.0}) + assert req.biomarkers["Glucose"] == 185.0 + + def test_empty_biomarkers(self): + with pytest.raises(ValidationError): + StructuredAnalysisRequest(biomarkers={}) + + +class TestAskRequest: + def test_valid(self): + req = AskRequest(question="What does high HbA1c mean?") + assert "HbA1c" in req.question + + def test_too_short(self): + with pytest.raises(ValidationError): + AskRequest(question="ab") + + def test_with_biomarkers(self): + req = AskRequest( + question="Explain my results", + biomarkers={"Glucose": 200.0}, + patient_context="52-year-old male", + ) + assert req.biomarkers is not None + + +class TestSearchRequest: + def test_defaults(self): + req = SearchRequest(query="diabetes guidelines") + assert req.top_k == 10 + assert req.mode == "hybrid" + + +class TestAskResponse: + def test_round_trip(self): + resp = AskResponse( + request_id="req_abc", + question="test?", + answer="test answer", + documents_retrieved=5, + documents_relevant=3, + processing_time_ms=123.4, + ) + data = resp.model_dump() + assert data["status"] == "success" + assert data["documents_relevant"] == 3 + + +class TestHealthResponse: + def test_basic(self): + resp = HealthResponse( + status="healthy", + timestamp="2025-01-01T00:00:00Z", + version="2.0.0", + uptime_seconds=42.0, + ) + assert resp.status == "healthy" diff --git 
def test_settings_defaults():
    """Settings should have sensible defaults without env vars."""
    from src.settings import get_settings

    # Build the settings against an empty environment. Otherwise exported
    # variables (e.g. CHUNKING__CHUNK_SIZE=1024 or OLLAMA__MODEL=llama3.2 from
    # a copied .env.example) override the defaults asserted below.
    with patch.dict(os.environ, {}, clear=True):
        # Clear any cached instance that was built from the real environment.
        get_settings.cache_clear()
        try:
            settings = get_settings()
            assert settings.api.port == 8000
            assert "mediguard" in settings.postgres.database_url
            assert "localhost" in settings.opensearch.host
            assert settings.redis.port == 6379
            assert settings.ollama.model == "llama3.1:8b"
            assert settings.embedding.dimension == 1024
            assert settings.chunking.chunk_size == 600
        finally:
            # Do not leave the instance built from the scrubbed environment
            # in the cache for later tests.
            get_settings.cache_clear()
+""" + +import pytest + +from src.services.indexing.text_chunker import MedicalChunk, MedicalTextChunker + + +@pytest.fixture +def chunker(): + return MedicalTextChunker(target_words=30, overlap_words=5, min_words=5) + + +def test_basic_chunking(chunker: MedicalTextChunker): + """Should split text into chunks.""" + # Generate enough words to require multiple chunks (target_words=30) + words = [f"word{i}" for i in range(200)] + text = " ".join(words) + chunks = chunker.chunk_text(text) + assert len(chunks) > 1 + for c in chunks: + assert isinstance(c, MedicalChunk) + assert c.text.strip() + + +def test_section_aware(chunker: MedicalTextChunker): + """Should detect section headers.""" + text = ( + "Introduction\nThis study examines diabetes.\n\n" + "Methods\nWe collected blood samples.\n\n" + "Results\nGlucose levels were elevated." + ) + chunks = chunker.chunk_text(text) + assert len(chunks) >= 1 + + +def test_biomarker_detection(chunker: MedicalTextChunker): + """Should detect biomarkers in chunks.""" + text = ( + "The patient's HbA1c was 8.2% indicating poor glycemic control. " + "Fasting glucose was 185 mg/dL and total cholesterol was elevated at 240." + ) + chunks = chunker.chunk_text(text) + assert len(chunks) >= 1 + # At least one chunk should have biomarkers detected + all_biomarkers = set() + for c in chunks: + all_biomarkers.update(c.biomarkers_mentioned) + assert len(all_biomarkers) > 0 + + +def test_condition_tagging(chunker: MedicalTextChunker): + """Should tag chunks with relevant conditions.""" + text = ( + "Diabetes mellitus is characterised by insulin resistance and elevated blood glucose. " + "Cardiovascular disease risk increases with uncontrolled hypertension." 
+ ) + chunks = chunker.chunk_text(text) + all_tags = set() + for c in chunks: + all_tags.update(c.condition_tags) + assert "diabetes" in all_tags or "heart_disease" in all_tags + + +def test_empty_text(chunker: MedicalTextChunker): + """Empty text should return empty list.""" + assert chunker.chunk_text("") == [] + assert chunker.chunk_text(" ") == []