Temporary commit — needs to be checked
Browse files- .python-version +1 -0
- README.md +0 -0
- main.py +6 -0
- pyproject.toml +28 -0
- scripts/ingest_data.py +40 -0
- src/Med_I_C.egg-info/PKG-INFO +24 -0
- src/Med_I_C.egg-info/SOURCES.txt +17 -0
- src/Med_I_C.egg-info/dependency_links.txt +1 -0
- src/Med_I_C.egg-info/requires.txt +18 -0
- src/Med_I_C.egg-info/top_level.txt +9 -0
- src/config.py +150 -0
- src/loader.py +201 -0
- src/state.py +125 -0
- uv.lock +0 -0
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.13
|
README.md
ADDED
|
File without changes
|
main.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def main() -> None:
    """Entry point for the Med-I-C package: emit a greeting to stdout."""
    print("Hello from med-i-c!")


if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "Med-I-C"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"langgraph>=0.0.15",
|
| 9 |
+
"langchain>=0.3.0",
|
| 10 |
+
"langchain-text-splitters",
|
| 11 |
+
"langchain-google-vertexai",
|
| 12 |
+
"google-cloud-aiplatform",
|
| 13 |
+
"chromadb>=0.4.0",
|
| 14 |
+
"sentence-transformers",
|
| 15 |
+
"transformers>=4.50.0",
|
| 16 |
+
"torch",
|
| 17 |
+
"accelerate",
|
| 18 |
+
"bitsandbytes",
|
| 19 |
+
"streamlit",
|
| 20 |
+
"pillow",
|
| 21 |
+
"pydantic>=2.0",
|
| 22 |
+
"python-dotenv",
|
| 23 |
+
"openpyxl",
|
| 24 |
+
"requests",
|
| 25 |
+
"pypdf",
|
| 26 |
+
"langchain-community>=0.4.1",
|
| 27 |
+
"jq>=1.11.0",
|
| 28 |
+
]
|
scripts/ingest_data.py
CHANGED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import chromadb
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, JSONLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

# 1. Setup Chroma Persistence
CHROMA_PATH = "data/chroma_db"
DATA_PATH = "data/Med-I-C/raw"


def ingest_medical_data():
    """
    Load antibiotic-guideline PDFs, chunk them, and store the chunks in a
    persistent ChromaDB collection named "antibiotic_guidelines".

    Reads PDFs from ``{DATA_PATH}/guidelines`` and persists the vector store
    under ``CHROMA_PATH``. Prints a summary and returns None.
    """
    # Persistent client for the competition (Kaggle/Local)
    client = chromadb.PersistentClient(path=CHROMA_PATH)

    # Embedding model used for the guideline collection
    model_name = "all-MiniLM-L6-v2"
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)

    # 2. Ingest Guidelines (PDFs)
    # We create a specific collection for cleaner retrieval
    guideline_col = client.get_or_create_collection(
        name="antibiotic_guidelines", embedding_function=ef
    )

    loader = DirectoryLoader(f"{DATA_PATH}/guidelines", glob="*.pdf", loader_cls=PyPDFLoader)
    documents = loader.load()

    # 1000/100 split as discussed for clinical coherence
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_documents(documents)

    # Guard: chromadb raises on an empty `ids` list, so bail out early when
    # no guideline PDFs were found instead of crashing.
    if not chunks:
        print("No guideline documents found; nothing to ingest.")
        return

    # NOTE(review): re-running this script against an existing collection will
    # raise on the duplicate "guideline_{i}" ids — consider `upsert` or
    # deleting the collection first. TODO confirm desired re-run semantics.
    guideline_col.add(
        ids=[f"guideline_{i}" for i in range(len(chunks))],
        documents=[c.page_content for c in chunks],
        metadatas=[c.metadata for c in chunks],
    )

    print(f"Successfully ingested {len(chunks)} guideline chunks.")


if __name__ == "__main__":
    ingest_medical_data()
|
src/Med_I_C.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: Med-I-C
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Add your description here
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Description-Content-Type: text/markdown
|
| 7 |
+
Requires-Dist: langgraph>=0.0.15
|
| 8 |
+
Requires-Dist: langchain>=0.3.0
|
| 9 |
+
Requires-Dist: langchain-text-splitters
|
| 10 |
+
Requires-Dist: langchain-google-vertexai
|
| 11 |
+
Requires-Dist: google-cloud-aiplatform
|
| 12 |
+
Requires-Dist: chromadb>=0.4.0
|
| 13 |
+
Requires-Dist: sentence-transformers
|
| 14 |
+
Requires-Dist: transformers>=4.50.0
|
| 15 |
+
Requires-Dist: torch
|
| 16 |
+
Requires-Dist: accelerate
|
| 17 |
+
Requires-Dist: bitsandbytes
|
| 18 |
+
Requires-Dist: streamlit
|
| 19 |
+
Requires-Dist: pillow
|
| 20 |
+
Requires-Dist: pydantic>=2.0
|
| 21 |
+
Requires-Dist: python-dotenv
|
| 22 |
+
Requires-Dist: openpyxl
|
| 23 |
+
Requires-Dist: requests
|
| 24 |
+
Requires-Dist: pypdf
|
src/Med_I_C.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
src/__init__.py
|
| 4 |
+
src/agents.py
|
| 5 |
+
src/config.py
|
| 6 |
+
src/graph.py
|
| 7 |
+
src/loader.py
|
| 8 |
+
src/prompts.py
|
| 9 |
+
src/rag.py
|
| 10 |
+
src/state.py
|
| 11 |
+
src/utils.py
|
| 12 |
+
src/Med_I_C.egg-info/PKG-INFO
|
| 13 |
+
src/Med_I_C.egg-info/SOURCES.txt
|
| 14 |
+
src/Med_I_C.egg-info/dependency_links.txt
|
| 15 |
+
src/Med_I_C.egg-info/requires.txt
|
| 16 |
+
src/Med_I_C.egg-info/top_level.txt
|
| 17 |
+
tests/test_pipeline.py
|
src/Med_I_C.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/Med_I_C.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langgraph>=0.0.15
|
| 2 |
+
langchain>=0.3.0
|
| 3 |
+
langchain-text-splitters
|
| 4 |
+
langchain-google-vertexai
|
| 5 |
+
google-cloud-aiplatform
|
| 6 |
+
chromadb>=0.4.0
|
| 7 |
+
sentence-transformers
|
| 8 |
+
transformers>=4.50.0
|
| 9 |
+
torch
|
| 10 |
+
accelerate
|
| 11 |
+
bitsandbytes
|
| 12 |
+
streamlit
|
| 13 |
+
pillow
|
| 14 |
+
pydantic>=2.0
|
| 15 |
+
python-dotenv
|
| 16 |
+
openpyxl
|
| 17 |
+
requests
|
| 18 |
+
pypdf
|
src/Med_I_C.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__init__
|
| 2 |
+
agents
|
| 3 |
+
config
|
| 4 |
+
graph
|
| 5 |
+
loader
|
| 6 |
+
prompts
|
| 7 |
+
rag
|
| 8 |
+
state
|
| 9 |
+
utils
|
src/config.py
CHANGED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Literal, Optional
|
| 8 |
+
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Load variables from a local .env if present (handy for local dev)
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Settings(BaseModel):
    """
    Central configuration object for Med-I-C.

    Values are read from environment variables where possible so that
    the same code can run locally, on Kaggle, and in production.

    All defaults are computed via ``default_factory`` so the environment is
    read at instantiation time (after ``load_dotenv()`` has run), not at
    import time.
    """

    # ------------------------------------------------------------------
    # General environment
    # ------------------------------------------------------------------
    # NOTE(review): an unexpected MEDIC_ENV value will fail pydantic's
    # Literal validation when Settings() is first built — confirm that a
    # hard failure (rather than a fallback to "local") is intended.
    environment: Literal["local", "kaggle", "production"] = Field(
        default_factory=lambda: os.getenv("MEDIC_ENV", "local")
    )

    # Repository root, derived from this file's location (src/config.py -> repo/).
    project_root: Path = Field(
        default_factory=lambda: Path(__file__).resolve().parents[1]
    )

    # NOTE(review): the default "data" is relative to the current working
    # directory, not to project_root — confirm callers always run from the
    # repo root or set MEDIC_DATA_DIR explicitly.
    data_dir: Path = Field(
        default_factory=lambda: Path(
            os.getenv("MEDIC_DATA_DIR", "data")
        )
    )

    chroma_db_dir: Path = Field(
        default_factory=lambda: Path(
            os.getenv("MEDIC_CHROMA_DB_DIR", "data/chroma_db")
        )
    )

    # ------------------------------------------------------------------
    # Model + deployment preferences
    # ------------------------------------------------------------------
    default_backend: Literal["vertex", "local"] = Field(
        default_factory=lambda: os.getenv("MEDIC_DEFAULT_BACKEND", "vertex")  # type: ignore[arg-type]
    )

    # Quantization mode for local models (consumed by src/loader.py when
    # loading transformers weights).
    quantization: Literal["none", "4bit"] = Field(
        default_factory=lambda: os.getenv("MEDIC_QUANTIZATION", "4bit")  # type: ignore[arg-type]
    )

    # Embedding model used for ChromaDB / RAG
    embedding_model_name: str = Field(
        default_factory=lambda: os.getenv(
            "MEDIC_EMBEDDING_MODEL",
            "sentence-transformers/all-MiniLM-L6-v2",
        )
    )

    # ------------------------------------------------------------------
    # Vertex AI configuration (MedGemma / TxGemma hosted on Vertex)
    # ------------------------------------------------------------------
    # Truthy values: "1", "true", "yes" (case-insensitive); anything else,
    # including "on" or "enabled", disables Vertex.
    use_vertex: bool = Field(
        default_factory=lambda: os.getenv("MEDIC_USE_VERTEX", "true").lower()
        in {"1", "true", "yes"}
    )

    # Required when the Vertex backend is used; src/loader.py raises if unset.
    vertex_project_id: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_VERTEX_PROJECT_ID")
    )
    vertex_location: str = Field(
        default_factory=lambda: os.getenv("MEDIC_VERTEX_LOCATION", "us-central1")
    )

    # Model IDs as expected by Vertex / langchain-google-vertexai
    vertex_medgemma_4b_model: str = Field(
        default_factory=lambda: os.getenv(
            "MEDIC_VERTEX_MEDGEMMA_4B_MODEL",
            "med-gemma-4b-it",
        )
    )
    vertex_medgemma_27b_model: str = Field(
        default_factory=lambda: os.getenv(
            "MEDIC_VERTEX_MEDGEMMA_27B_MODEL",
            "med-gemma-27b-text-it",
        )
    )
    vertex_txgemma_9b_model: str = Field(
        default_factory=lambda: os.getenv(
            "MEDIC_VERTEX_TXGEMMA_9B_MODEL",
            "tx-gemma-9b",
        )
    )
    vertex_txgemma_2b_model: str = Field(
        default_factory=lambda: os.getenv(
            "MEDIC_VERTEX_TXGEMMA_2B_MODEL",
            "tx-gemma-2b",
        )
    )

    # Standard GOOGLE_APPLICATION_CREDENTIALS path, if needed
    google_application_credentials: Optional[Path] = Field(
        default_factory=lambda: (
            Path(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
            if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ
            else None
        )
    )

    # ------------------------------------------------------------------
    # Local model paths (for offline / Kaggle GPU usage)
    # ------------------------------------------------------------------
    # These default to None; src/loader.py raises a RuntimeError when a
    # local model is requested but its path is unset.
    local_medgemma_4b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_4B_MODEL")
    )
    local_medgemma_27b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_27B_MODEL")
    )
    local_txgemma_9b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_TXGEMMA_9B_MODEL")
    )
    local_txgemma_2b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_TXGEMMA_2B_MODEL")
    )
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """
    Return a process-wide, lazily constructed Settings instance.

    The first call builds Settings (reading the environment at that moment);
    every later call returns the same cached object. Prefer this helper over
    instantiating Settings directly:

        from src.config import get_settings
        settings = get_settings()
    """
    return Settings()


__all__ = ["Settings", "get_settings"]
|
| 150 |
+
|
src/loader.py
CHANGED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
from typing import Any, Callable, Dict, Literal, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from .config import get_settings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
TextBackend = Literal["vertex", "local"]
|
| 14 |
+
TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _resolve_backend(
    requested: Optional[TextBackend],
) -> TextBackend:
    """
    Resolve the effective text backend.

    An explicit request wins; otherwise the configured default is used.
    A "vertex" choice is downgraded to "local" when Vertex is disabled
    in settings.
    """
    settings = get_settings()
    backend = requested or settings.default_backend  # type: ignore[assignment]
    if backend != "vertex" or settings.use_vertex:
        return backend
    logger.info("Vertex disabled in settings, falling back to local backend.")
    return "local"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@lru_cache(maxsize=8)
def _get_vertex_chat_model(model_name: TextModelName):
    """
    Lazily construct a Vertex AI chat model via langchain-google-vertexai.

    Returns an object with an .invoke(str) method; we wrap this in a simple
    callable for downstream use.

    Raises:
        RuntimeError: if langchain-google-vertexai cannot be imported, or if
            MEDIC_VERTEX_PROJECT_ID is not configured.
    """

    # Import lazily so the local backend works without the Vertex extras.
    try:
        from langchain_google_vertexai import ChatVertexAI
    except Exception as exc:  # pragma: no cover - import-time failure
        raise RuntimeError(
            "langchain-google-vertexai is not available; "
            "install it or switch MEDIC_DEFAULT_BACKEND=local."
        ) from exc

    settings = get_settings()

    if settings.vertex_project_id is None:
        raise RuntimeError(
            "MEDIC_VERTEX_PROJECT_ID is not set. "
            "Set it in your environment or .env to use Vertex AI."
        )

    # Map the short internal model name to the Vertex model ID from settings.
    model_id_map: Dict[TextModelName, str] = {
        "medgemma_4b": settings.vertex_medgemma_4b_model,
        "medgemma_27b": settings.vertex_medgemma_27b_model,
        "txgemma_9b": settings.vertex_txgemma_9b_model,
        "txgemma_2b": settings.vertex_txgemma_2b_model,
    }
    model_id = model_id_map[model_name]

    llm = ChatVertexAI(
        model=model_id,
        project=settings.vertex_project_id,
        location=settings.vertex_location,
        temperature=0.2,
    )

    def _call(prompt: str, **kwargs: Any) -> str:
        """Thin wrapper returning plain text from ChatVertexAI."""

        # NOTE(review): kwargs are forwarded verbatim to llm.invoke();
        # run_inference() passes max_new_tokens/temperature, which may not be
        # accepted keyword arguments for ChatVertexAI.invoke — TODO confirm.
        result = llm.invoke(prompt, **kwargs)
        # langchain BaseMessage or plain string
        content = getattr(result, "content", result)
        return str(content)

    return _call
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@lru_cache(maxsize=8)
def _get_local_causal_lm(model_name: TextModelName):
    """
    Lazily load a local transformers model for offline / Kaggle usage.

    Assumes model paths are provided via MEDIC_LOCAL_* env vars and that
    the appropriate model weights are available in the environment.

    Returns:
        A callable ``(prompt, max_new_tokens=512, temperature=0.2, **kw) -> str``
        that decodes only the newly generated completion.

    Raises:
        RuntimeError: if no local model path is configured for *model_name*.
    """

    # Import lazily so the Vertex backend works without torch/transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    settings = get_settings()

    model_path_map: Dict[TextModelName, Optional[str]] = {
        "medgemma_4b": settings.local_medgemma_4b_model,
        "medgemma_27b": settings.local_medgemma_27b_model,
        "txgemma_9b": settings.local_txgemma_9b_model,
        "txgemma_2b": settings.local_txgemma_2b_model,
    }

    model_path = model_path_map[model_name]
    if not model_path:
        raise RuntimeError(
            f"No local model path configured for {model_name}. "
            f"Set MEDIC_LOCAL_*_MODEL or use the Vertex backend."
        )

    load_kwargs: Dict[str, Any] = {
        "device_map": "auto",
    }

    # Optional 4-bit quantization via bitsandbytes
    # NOTE(review): `load_in_4bit=True` is deprecated in recent transformers
    # releases in favor of BitsAndBytesConfig — confirm the pinned version.
    if get_settings().quantization == "4bit":
        load_kwargs["load_in_4bit"] = True

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

    def _call(
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.2,
        **generate_kwargs: Any,
    ) -> str:
        # Tokenize and move tensors to wherever device_map placed the model.
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # temperature <= 0 is treated as greedy decoding.
        do_sample = temperature > 0

        # NOTE(review): passing temperature=0.0 alongside do_sample=False
        # triggers a transformers warning in some versions — consider omitting
        # temperature entirely for the greedy path. TODO confirm.
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                do_sample=do_sample,
                temperature=temperature if do_sample else 0.0,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )

        # Drop the prompt tokens and decode only the completion
        generated_ids = output_ids[0, inputs["input_ids"].shape[1] :]
        text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        return text.strip()

    return _call
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@lru_cache(maxsize=32)
def get_text_model(
    model_name: TextModelName = "medgemma_4b",
    backend: Optional[TextBackend] = None,
) -> Callable[..., str]:
    """
    Return a cached text-generation callable for the given model and backend.

    Example:

        from src.loader import get_text_model
        model = get_text_model("medgemma_4b")
        answer = model("Explain ESBL in simple terms.")
    """
    resolved = _resolve_backend(backend)
    factory = _get_vertex_chat_model if resolved == "vertex" else _get_local_causal_lm
    return factory(model_name)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def run_inference(
    prompt: str,
    model_name: TextModelName = "medgemma_4b",
    backend: Optional[TextBackend] = None,
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    **kwargs: Any,
) -> str:
    """
    Convenience wrapper around `get_text_model`.

    This is the simplest entry point to use inside agents:

        from src.loader import run_inference
        text = run_inference(prompt, model_name="medgemma_4b")
    """
    # Resolve (and cache) the backend-specific callable, then delegate.
    generate = get_text_model(model_name=model_name, backend=backend)
    return generate(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
__all__ = [
|
| 196 |
+
"TextBackend",
|
| 197 |
+
"TextModelName",
|
| 198 |
+
"get_text_model",
|
| 199 |
+
"run_inference",
|
| 200 |
+
]
|
| 201 |
+
|
src/state.py
CHANGED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Dict, List, Literal, NotRequired, Optional, TypedDict
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LabResult(TypedDict, total=False):
    """Structured representation of a single lab value."""

    # NOTE(review): with total=False every key here is already optional, so
    # the NotRequired wrappers are redundant (and `name`/`value` are NOT
    # actually required). Also, `typing.NotRequired` needs Python 3.11+ while
    # pyproject declares requires-python >=3.10 — consider typing_extensions.
    name: str
    value: str  # stored as text, not parsed to a number
    unit: NotRequired[Optional[str]]
    reference_range: NotRequired[Optional[str]]
    flag: NotRequired[Optional[Literal["low", "normal", "high", "critical"]]]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MICDatum(TypedDict, total=False):
    """Single MIC measurement for a bug–drug pair."""

    # total=False: every key is optional; NotRequired markers below are
    # informational only.
    organism: str
    antibiotic: str
    mic_value: str  # kept as text (e.g. ">=16"), not parsed numerically
    mic_unit: NotRequired[Optional[str]]
    interpretation: NotRequired[Optional[Literal["S", "I", "R"]]]
    breakpoint_source: NotRequired[Optional[str]]  # e.g. EUCAST v16.0
    year: NotRequired[Optional[int]]
    site: NotRequired[Optional[str]]  # e.g. blood, urine
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Recommendation(TypedDict, total=False):
    """Final clinical recommendation assembled by Agent 4."""

    # total=False makes every key optional; Optional[...] additionally allows
    # an explicit None value for a present key.
    primary_antibiotic: Optional[str]
    backup_antibiotic: NotRequired[Optional[str]]
    dose: Optional[str]
    route: Optional[str]
    frequency: Optional[str]
    duration: Optional[str]
    rationale: Optional[str]
    references: NotRequired[List[str]]
    safety_alerts: NotRequired[List[str]]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class InfectionState(TypedDict, total=False):
    """
    Global LangGraph state for the Med-I-C pipeline.

    All agents read from and write back to this object.
    Most keys are optional to keep the schema flexible across stages.
    (total=False already makes every key optional; the NotRequired markers
    are informational.)
    """

    # ------------------------------------------------------------------
    # Patient identity & demographics
    # ------------------------------------------------------------------
    patient_id: NotRequired[Optional[str]]
    age_years: NotRequired[Optional[float]]
    sex: NotRequired[Optional[Literal["male", "female", "other", "unknown"]]]
    weight_kg: NotRequired[Optional[float]]
    height_cm: NotRequired[Optional[float]]

    # ------------------------------------------------------------------
    # Clinical context
    # ------------------------------------------------------------------
    suspected_source: NotRequired[Optional[str]]  # e.g. "community UTI"
    comorbidities: NotRequired[List[str]]
    medications: NotRequired[List[str]]
    allergies: NotRequired[List[str]]
    infection_site: NotRequired[Optional[str]]
    country_or_region: NotRequired[Optional[str]]

    # ------------------------------------------------------------------
    # Renal function / vitals
    # ------------------------------------------------------------------
    serum_creatinine_mg_dl: NotRequired[Optional[float]]
    creatinine_clearance_ml_min: NotRequired[Optional[float]]
    vitals: NotRequired[Dict[str, str]]  # flexible key/value, e.g. {"BP": "120/80"}

    # ------------------------------------------------------------------
    # Lab data & MICs
    # ------------------------------------------------------------------
    labs_raw_text: NotRequired[Optional[str]]  # raw OCR / PDF text
    labs_parsed: NotRequired[List[LabResult]]

    mic_data: NotRequired[List[MICDatum]]
    mic_trend_summary: NotRequired[Optional[str]]

    # ------------------------------------------------------------------
    # Stage / routing metadata
    # ------------------------------------------------------------------
    stage: NotRequired[Literal["empirical", "targeted"]]
    route_to_vision: NotRequired[bool]
    route_to_trend_analyst: NotRequired[bool]

    # ------------------------------------------------------------------
    # Agent outputs
    # ------------------------------------------------------------------
    intake_notes: NotRequired[Optional[str]]  # Agent 1
    vision_notes: NotRequired[Optional[str]]  # Agent 2
    trend_notes: NotRequired[Optional[str]]  # Agent 3
    pharmacology_notes: NotRequired[Optional[str]]  # Agent 4

    recommendation: NotRequired[Optional[Recommendation]]

    # ------------------------------------------------------------------
    # RAG / context + safety
    # ------------------------------------------------------------------
    rag_context: NotRequired[Optional[str]]
    guideline_sources: NotRequired[List[str]]
    breakpoint_sources: NotRequired[List[str]]
    safety_warnings: NotRequired[List[str]]

    # ------------------------------------------------------------------
    # Diagnostics / debugging
    # ------------------------------------------------------------------
    errors: NotRequired[List[str]]
    debug_log: NotRequired[List[str]]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
__all__ = [
|
| 120 |
+
"LabResult",
|
| 121 |
+
"MICDatum",
|
| 122 |
+
"Recommendation",
|
| 123 |
+
"InfectionState",
|
| 124 |
+
]
|
| 125 |
+
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|