Refactor environment configuration and remove Vertex AI dependencies
- .env.example +0 -2
- README.md +1 -15
- notebooks/kaggle_medic_demo.ipynb +2 -30
- src/config.py +6 -41
- src/loader.py +8 -61
.env.example
CHANGED

@@ -4,8 +4,6 @@
 
 # ── General ───────────────────────────────────────────────────────────────────
 MEDIC_ENV=local               # local | kaggle | production
-MEDIC_DEFAULT_BACKEND=local   # local | vertex
-MEDIC_USE_VERTEX=false
 MEDIC_QUANTIZATION=4bit       # none | 4bit
 
 # ── Local HuggingFace Models ──────────────────────────────────────────────────
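As a quick sanity check on the trimmed file, a minimal sketch (assuming `python-dotenv` is installed and a local `.env` was copied from `.env.example`) that confirms the Vertex keys are gone while the general keys still resolve:

```python
# Hedged sketch: verify the trimmed .env no longer carries Vertex settings.
# Assumes python-dotenv is installed and .env was copied from .env.example.
from dotenv import dotenv_values

values = dotenv_values(".env")
assert "MEDIC_DEFAULT_BACKEND" not in values
assert "MEDIC_USE_VERTEX" not in values
print("MEDIC_ENV =", values.get("MEDIC_ENV"))                    # typically "local"
print("MEDIC_QUANTIZATION =", values.get("MEDIC_QUANTIZATION"))  # typically "4bit"
```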
README.md
CHANGED

@@ -47,8 +47,6 @@ Patient form ──► Agent 1: Intake Historian ──► (no lab) ───
 - HuggingFace account with access granted to:
   - [MedGemma](https://huggingface.co/google/medgemma-4b-it)
   - [TxGemma](https://huggingface.co/google/txgemma-2b-predict)
-- **For cloud deployment:** Google Cloud project with Vertex AI enabled
-
 ---
 
 ## Setup

@@ -68,9 +66,6 @@ cp .env.example .env
 Edit `.env`. Minimum required settings:
 
 ```bash
-# Choose your backend
-MEDIC_DEFAULT_BACKEND=local   # local | vertex
-
 # Local model IDs (HuggingFace)
 MEDIC_LOCAL_MEDGEMMA_4B_MODEL=google/medgemma-4b-it
 MEDIC_LOCAL_MEDGEMMA_27B_MODEL=google/medgemma-4b-it   # use 4B as fallback if <24 GB VRAM

@@ -78,15 +73,6 @@ MEDIC_LOCAL_TXGEMMA_9B_MODEL=google/txgemma-2b-predict
 MEDIC_LOCAL_TXGEMMA_2B_MODEL=google/txgemma-2b-predict
 ```
 
-For Vertex AI instead:
-
-```bash
-MEDIC_DEFAULT_BACKEND=vertex
-MEDIC_USE_VERTEX=true
-MEDIC_VERTEX_PROJECT_ID=your-gcp-project-id
-MEDIC_VERTEX_LOCATION=us-central1
-```
-
 ### 3. Authenticate with HuggingFace
 
 ```bash

@@ -154,7 +140,7 @@ medic-amr-guard/
 ├── src/
 │   ├── agents.py     # Four agent implementations
 │   ├── graph.py      # LangGraph orchestrator + conditional routing
-│   ├── loader.py     # Model loading: local HuggingFace
+│   ├── loader.py     # Model loading: local HuggingFace causal LMs
 │   ├── prompts.py    # System and user prompts for all agents
 │   ├── rag.py        # ChromaDB ingestion and retrieval helpers
 │   ├── state.py      # InfectionState TypedDict schema
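Since both model families are gated on HuggingFace, a minimal sketch of the authentication step the README points to. The `HF_TOKEN` environment variable here is an assumption for illustration, not something the README mandates:

```python
# Hedged sketch: programmatic HuggingFace login for the gated MedGemma/TxGemma
# repos. Assumes an HF_TOKEN environment variable holding a read token.
import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])
```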
notebooks/kaggle_medic_demo.ipynb
CHANGED

@@ -232,35 +232,7 @@
    "id": "a61f1fb1",
    "metadata": {},
    "outputs": [],
-   "source": [
-    "# Write .env\n",
-    "env = f\"\"\"\n",
-    "MEDIC_ENV=kaggle\n",
-    "MEDIC_DEFAULT_BACKEND=local\n",
-    "MEDIC_USE_VERTEX=false\n",
-    "MEDIC_QUANTIZATION=4bit\n",
-    "\n",
-    "# Agent 1, 2, 4 — MedGemma 4B IT\n",
-    "MEDIC_LOCAL_MEDGEMMA_4B_MODEL={MEDGEMMA_4B}\n",
-    "\n",
-    "# Agent 3 — MedGemma 27B Text IT (subbed with 4B for Kaggle T4)\n",
-    "# To use full 27B: set to google/medgemma-27b-text-it\n",
-    "MEDIC_LOCAL_MEDGEMMA_27B_MODEL={MEDGEMMA_4B}\n",
-    "\n",
-    "# Agent 4 safety — TxGemma 9B (subbed with 2B for Kaggle T4)\n",
-    "# To use full 9B: set to google/txgemma-9b-predict\n",
-    "MEDIC_LOCAL_TXGEMMA_9B_MODEL={TXGEMMA_2B}\n",
-    "MEDIC_LOCAL_TXGEMMA_2B_MODEL={TXGEMMA_2B}\n",
-    "\n",
-    "MEDIC_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2\n",
-    "MEDIC_DATA_DIR=/kaggle/working/AMR-Guard/data\n",
-    "MEDIC_CHROMA_DB_DIR=/kaggle/working/AMR-Guard/data/chroma_db\n",
-    "\"\"\".strip()\n",
-    "\n",
-    "with open(\"/kaggle/working/AMR-Guard/.env\", \"w\") as f:\n",
-    "    f.write(env)\n",
-    "print(\".env written\")"
-   ]
+   "source": "# Write .env\nenv = f\"\"\"\nMEDIC_ENV=kaggle\nMEDIC_QUANTIZATION=4bit\n\n# Agent 1, 2, 4 — MedGemma 4B IT\nMEDIC_LOCAL_MEDGEMMA_4B_MODEL={MEDGEMMA_4B}\n\n# Agent 3 — MedGemma 27B Text IT (subbed with 4B for Kaggle T4)\n# To use full 27B: set to google/medgemma-27b-text-it\nMEDIC_LOCAL_MEDGEMMA_27B_MODEL={MEDGEMMA_4B}\n\n# Agent 4 safety — TxGemma 9B (subbed with 2B for Kaggle T4)\n# To use full 9B: set to google/txgemma-9b-predict\nMEDIC_LOCAL_TXGEMMA_9B_MODEL={TXGEMMA_2B}\nMEDIC_LOCAL_TXGEMMA_2B_MODEL={TXGEMMA_2B}\n\nMEDIC_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2\nMEDIC_DATA_DIR=/kaggle/working/AMR-Guard/data\nMEDIC_CHROMA_DB_DIR=/kaggle/working/AMR-Guard/data/chroma_db\n\"\"\".strip()\n\nwith open(\"/kaggle/working/AMR-Guard/.env\", \"w\") as f:\n    f.write(env)\nprint(\".env written\")"
   },
   {
    "cell_type": "code",

@@ -683,4 +655,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}
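For readability, the new single-string `source` above unescapes to the cell sketched below. `MEDGEMMA_4B` and `TXGEMMA_2B` are defined earlier in the notebook; the values shown here are assumed from the README's local model IDs:

```python
# Hedged sketch of the rewritten notebook cell, unescaped from the JSON string.
# MEDGEMMA_4B / TXGEMMA_2B are assumed to match the IDs in the README.
MEDGEMMA_4B = "google/medgemma-4b-it"
TXGEMMA_2B = "google/txgemma-2b-predict"

env = f"""
MEDIC_ENV=kaggle
MEDIC_QUANTIZATION=4bit

# Agent 1, 2, 4 — MedGemma 4B IT
MEDIC_LOCAL_MEDGEMMA_4B_MODEL={MEDGEMMA_4B}

# Agent 3 — MedGemma 27B Text IT (subbed with 4B for Kaggle T4)
MEDIC_LOCAL_MEDGEMMA_27B_MODEL={MEDGEMMA_4B}

# Agent 4 safety — TxGemma 9B (subbed with 2B for Kaggle T4)
MEDIC_LOCAL_TXGEMMA_9B_MODEL={TXGEMMA_2B}
MEDIC_LOCAL_TXGEMMA_2B_MODEL={TXGEMMA_2B}

MEDIC_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
MEDIC_DATA_DIR=/kaggle/working/AMR-Guard/data
MEDIC_CHROMA_DB_DIR=/kaggle/working/AMR-Guard/data/chroma_db
""".strip()

# The notebook writes this string to /kaggle/working/AMR-Guard/.env.
print(env)
```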
src/config.py
CHANGED

@@ -18,7 +18,6 @@ class Settings(BaseModel):
     All configuration for AMR-Guard, read from environment variables.
 
     Supports three deployment targets via MEDIC_ENV: local, kaggle, production.
-    Backend selection (vertex or local) is controlled by MEDIC_DEFAULT_BACKEND.
     """
 
     environment: Literal["local", "kaggle", "production"] = Field(

@@ -34,10 +33,7 @@ class Settings(BaseModel):
         default_factory=lambda: Path(os.getenv("MEDIC_CHROMA_DB_DIR", "data/chroma_db"))
     )
 
-    default_backend: Literal["local", "vertex"] = Field(
-        default_factory=lambda: os.getenv("MEDIC_DEFAULT_BACKEND", "local")  # type: ignore[arg-type]
-    )
-    # 4-bit quantization via bitsandbytes (local backend only)
+    # 4-bit quantization via bitsandbytes
     quantization: Literal["none", "4bit"] = Field(
         default_factory=lambda: os.getenv("MEDIC_QUANTIZATION", "4bit")  # type: ignore[arg-type]
     )

@@ -45,47 +41,17 @@
         default_factory=lambda: os.getenv("MEDIC_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
     )
 
-    #
-    use_vertex: bool = Field(
-        default_factory=lambda: os.getenv("MEDIC_USE_VERTEX", "true").lower() in {"1", "true", "yes"}
-    )
-    vertex_project_id: Optional[str] = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_PROJECT_ID")
-    )
-    vertex_location: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_LOCATION", "us-central1")
-    )
-    vertex_medgemma_4b_model: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_MEDGEMMA_4B_MODEL", "med-gemma-4b-it")
-    )
-    vertex_medgemma_27b_model: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_MEDGEMMA_27B_MODEL", "med-gemma-27b-text-it")
-    )
-    vertex_txgemma_9b_model: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_TXGEMMA_9B_MODEL", "tx-gemma-9b")
-    )
-    vertex_txgemma_2b_model: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_TXGEMMA_2B_MODEL", "tx-gemma-2b")
-    )
-    google_application_credentials: Optional[Path] = Field(
-        default_factory=lambda: (
-            Path(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
-            if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ
-            else None
-        )
-    )
-
-    # Local HuggingFace model paths (used when MEDIC_DEFAULT_BACKEND=local)
-    local_medgemma_4b_model: Optional[str] = Field(
+    # Local HuggingFace model paths
+    medgemma_4b_model: Optional[str] = Field(
         default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_4B_MODEL")
     )
-    local_medgemma_27b_model: Optional[str] = Field(
+    medgemma_27b_model: Optional[str] = Field(
         default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_27B_MODEL")
     )
-    local_txgemma_9b_model: Optional[str] = Field(
+    txgemma_9b_model: Optional[str] = Field(
         default_factory=lambda: os.getenv("MEDIC_LOCAL_TXGEMMA_9B_MODEL")
     )
-    local_txgemma_2b_model: Optional[str] = Field(
+    txgemma_2b_model: Optional[str] = Field(
         default_factory=lambda: os.getenv("MEDIC_LOCAL_TXGEMMA_2B_MODEL")
     )
 

@@ -94,4 +60,3 @@ class Settings(BaseModel):
 def get_settings() -> Settings:
     """Return the cached Settings singleton. Import this instead of instantiating Settings directly."""
     return Settings()
-
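After this change the settings surface is local-only, and the model fields lose their `local_` prefix. A minimal usage sketch of the renamed fields (assuming the package is importable as `src.config` from the repo root):

```python
# Hedged usage sketch: callers now read the renamed local-model fields directly
# (previously local_medgemma_4b_model etc.).
from src.config import get_settings

settings = get_settings()
print(settings.quantization)        # "4bit" unless MEDIC_QUANTIZATION=none
print(settings.medgemma_4b_model)   # from MEDIC_LOCAL_MEDGEMMA_4B_MODEL, else None
print(settings.txgemma_2b_model)    # from MEDIC_LOCAL_TXGEMMA_2B_MODEL, else None
```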
src/loader.py
CHANGED

@@ -7,58 +7,9 @@ from .config import get_settings
 
 logger = logging.getLogger(__name__)
 
-TextBackend = Literal["vertex", "local"]
 TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
 
 
-def _resolve_backend(requested: Optional[TextBackend]) -> TextBackend:
-    settings = get_settings()
-    backend = requested or settings.default_backend  # type: ignore[assignment]
-    if backend == "vertex" and not settings.use_vertex:
-        logger.info("Vertex disabled in settings, falling back to local backend.")
-        return "local"
-    return backend
-
-
-@lru_cache(maxsize=8)
-def _get_vertex_chat_model(model_name: TextModelName):
-    """Load a Vertex AI chat model and return a callable that takes a prompt string."""
-    try:
-        from langchain_google_vertexai import ChatVertexAI
-    except Exception as exc:
-        raise RuntimeError(
-            "langchain-google-vertexai is not available; "
-            "install it or switch MEDIC_DEFAULT_BACKEND=local."
-        ) from exc
-
-    settings = get_settings()
-    if settings.vertex_project_id is None:
-        raise RuntimeError(
-            "MEDIC_VERTEX_PROJECT_ID is not set. "
-            "Set it in your environment or .env to use Vertex AI."
-        )
-
-    model_id_map: Dict[TextModelName, str] = {
-        "medgemma_4b": settings.vertex_medgemma_4b_model,
-        "medgemma_27b": settings.vertex_medgemma_27b_model,
-        "txgemma_9b": settings.vertex_txgemma_9b_model,
-        "txgemma_2b": settings.vertex_txgemma_2b_model,
-    }
-
-    llm = ChatVertexAI(
-        model=model_id_map[model_name],
-        project=settings.vertex_project_id,
-        location=settings.vertex_location,
-        temperature=0.2,
-    )
-
-    def _call(prompt: str, **kwargs: Any) -> str:
-        result = llm.invoke(prompt, **kwargs)
-        return str(getattr(result, "content", result))
-
-    return _call
-
-
 @lru_cache(maxsize=8)
 def _get_local_causal_lm(model_name: TextModelName):
     """Load a local HuggingFace causal LM and return a generation callable."""

@@ -67,17 +18,17 @@ def _get_local_causal_lm(model_name: TextModelName):
 
     settings = get_settings()
     model_path_map: Dict[TextModelName, Optional[str]] = {
-        "medgemma_4b": settings.local_medgemma_4b_model,
-        "medgemma_27b": settings.local_medgemma_27b_model,
-        "txgemma_9b": settings.local_txgemma_9b_model,
-        "txgemma_2b": settings.local_txgemma_2b_model,
+        "medgemma_4b": settings.medgemma_4b_model,
+        "medgemma_27b": settings.medgemma_27b_model,
+        "txgemma_9b": settings.txgemma_9b_model,
+        "txgemma_2b": settings.txgemma_2b_model,
     }
 
     model_path = model_path_map[model_name]
     if not model_path:
         raise RuntimeError(
             f"No local model path configured for {model_name}. "
-            "Set MEDIC_LOCAL_*_MODEL
+            f"Set MEDIC_LOCAL_*_MODEL in your environment or .env."
         )
 
     load_kwargs: Dict[str, Any] = {"device_map": "auto"}

@@ -108,22 +59,18 @@
 @lru_cache(maxsize=32)
 def get_text_model(
     model_name: TextModelName = "medgemma_4b",
-    backend: Optional[TextBackend] = None,
 ) -> Callable[..., str]:
-    """Return a cached callable for the requested model
-    resolved = _resolve_backend(backend)
-    return _get_vertex_chat_model(model_name) if resolved == "vertex" else _get_local_causal_lm(model_name)
+    """Return a cached callable for the requested model."""
+    return _get_local_causal_lm(model_name)
 
 
 def run_inference(
     prompt: str,
     model_name: TextModelName = "medgemma_4b",
-    backend: Optional[TextBackend] = None,
     max_new_tokens: int = 512,
     temperature: float = 0.2,
     **kwargs: Any,
 ) -> str:
     """Run inference with the specified model. This is the primary entry point for agents."""
-    model = get_text_model(model_name=model_name, backend=backend)
+    model = get_text_model(model_name=model_name)
     return model(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
-
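With the backend parameter removed, the call path for agents reduces to a model name plus generation kwargs. A minimal usage sketch (assumes the package is importable as `src.loader` and that the MEDIC_LOCAL_*_MODEL variables point at accessible checkpoints):

```python
# Hedged usage sketch of the simplified entry point: no backend argument,
# every model name resolves to a local HuggingFace causal LM.
from src.loader import run_inference

answer = run_inference(
    "Summarise first-line empiric therapy considerations for uncomplicated cystitis.",
    model_name="medgemma_4b",   # one of: medgemma_4b, medgemma_27b, txgemma_9b, txgemma_2b
    max_new_tokens=256,
    temperature=0.2,
)
print(answer)
```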