ghitaben committed
Commit 936bc6b · 1 Parent(s): 2cec50c

Refactor environment configuration and remove Vertex AI dependencies

Files changed (5)
  1. .env.example +0 -2
  2. README.md +1 -15
  3. notebooks/kaggle_medic_demo.ipynb +2 -30
  4. src/config.py +6 -41
  5. src/loader.py +8 -61
.env.example CHANGED
@@ -4,8 +4,6 @@
 
 # ── General ───────────────────────────────────────────────────────────────────
 MEDIC_ENV=local              # local | kaggle | production
-MEDIC_DEFAULT_BACKEND=local  # local | vertex
-MEDIC_USE_VERTEX=false
 MEDIC_QUANTIZATION=4bit      # none | 4bit
 
 # ── Local HuggingFace Models ──────────────────────────────────────────────────
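
With the backend switches gone, what remains is plain environment configuration. A minimal sketch of how these values reach `src/config.py`, assuming the `python-dotenv` package loads `.env` before settings are read (the loading hook itself is not shown in this commit):

```python
# Sketch only: make .env values visible to the os.getenv calls in src/config.py.
# Assumes python-dotenv is installed; adjust the path if .env lives elsewhere.
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory

assert os.getenv("MEDIC_ENV") in {"local", "kaggle", "production"}
print(os.getenv("MEDIC_QUANTIZATION", "4bit"))  # "none" | "4bit"
```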
README.md CHANGED
@@ -47,8 +47,6 @@ Patient form ──► Agent 1: Intake Historian ──► (no lab) ───
 - HuggingFace account with access granted to:
   - [MedGemma](https://huggingface.co/google/medgemma-4b-it)
   - [TxGemma](https://huggingface.co/google/txgemma-2b-predict)
-- **For cloud deployment:** Google Cloud project with Vertex AI enabled
-
 ---
 
 ## Setup
@@ -68,9 +66,6 @@ cp .env.example .env
 Edit `.env`. Minimum required settings:
 
 ```bash
-# Choose your backend
-MEDIC_DEFAULT_BACKEND=local  # local | vertex
-
 # Local model IDs (HuggingFace)
 MEDIC_LOCAL_MEDGEMMA_4B_MODEL=google/medgemma-4b-it
 MEDIC_LOCAL_MEDGEMMA_27B_MODEL=google/medgemma-4b-it  # use 4B as fallback if <24 GB VRAM
@@ -78,15 +73,6 @@ MEDIC_LOCAL_TXGEMMA_9B_MODEL=google/txgemma-2b-predict
 MEDIC_LOCAL_TXGEMMA_2B_MODEL=google/txgemma-2b-predict
 ```
 
-For Vertex AI instead:
-
-```bash
-MEDIC_DEFAULT_BACKEND=vertex
-MEDIC_USE_VERTEX=true
-MEDIC_VERTEX_PROJECT_ID=your-gcp-project-id
-MEDIC_VERTEX_LOCATION=us-central1
-```
-
 ### 3. Authenticate with HuggingFace
 
 ```bash
@@ -154,7 +140,7 @@ medic-amr-guard/
 ├── src/
 │   ├── agents.py    # Four agent implementations
 │   ├── graph.py     # LangGraph orchestrator + conditional routing
-│   ├── loader.py    # Model loading: local HuggingFace or Vertex AI
+│   ├── loader.py    # Model loading: local HuggingFace causal LMs
 │   ├── prompts.py   # System and user prompts for all agents
 │   ├── rag.py       # ChromaDB ingestion and retrieval helpers
 │   ├── state.py     # InfectionState TypedDict schema
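
Both model repos in the prerequisites are gated, so access must be granted before setup step 3 succeeds. A quick check, sketched here with the `huggingface_hub` client (this snippet is not part of the README and assumes a completed login):

```python
# Sketch: verify gated-model access after authenticating with HuggingFace.
# model_info raises an error (e.g. a gated-repo error) if access is missing.
from huggingface_hub import model_info

for repo in ("google/medgemma-4b-it", "google/txgemma-2b-predict"):
    model_info(repo)  # raises if access has not been granted
    print(f"{repo}: access OK")
```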
notebooks/kaggle_medic_demo.ipynb CHANGED
@@ -232,35 +232,7 @@
    "id": "a61f1fb1",
    "metadata": {},
    "outputs": [],
-   "source": [
-    "# Write .env\n",
-    "env = f\"\"\"\n",
-    "MEDIC_ENV=kaggle\n",
-    "MEDIC_DEFAULT_BACKEND=local\n",
-    "MEDIC_USE_VERTEX=false\n",
-    "MEDIC_QUANTIZATION=4bit\n",
-    "\n",
-    "# Agent 1, 2, 4 — MedGemma 4B IT\n",
-    "MEDIC_LOCAL_MEDGEMMA_4B_MODEL={MEDGEMMA_4B}\n",
-    "\n",
-    "# Agent 3 — MedGemma 27B Text IT (subbed with 4B for Kaggle T4)\n",
-    "# To use full 27B: set to google/medgemma-27b-text-it\n",
-    "MEDIC_LOCAL_MEDGEMMA_27B_MODEL={MEDGEMMA_4B}\n",
-    "\n",
-    "# Agent 4 safety — TxGemma 9B (subbed with 2B for Kaggle T4)\n",
-    "# To use full 9B: set to google/txgemma-9b-predict\n",
-    "MEDIC_LOCAL_TXGEMMA_9B_MODEL={TXGEMMA_2B}\n",
-    "MEDIC_LOCAL_TXGEMMA_2B_MODEL={TXGEMMA_2B}\n",
-    "\n",
-    "MEDIC_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2\n",
-    "MEDIC_DATA_DIR=/kaggle/working/AMR-Guard/data\n",
-    "MEDIC_CHROMA_DB_DIR=/kaggle/working/AMR-Guard/data/chroma_db\n",
-    "\"\"\".strip()\n",
-    "\n",
-    "with open(\"/kaggle/working/AMR-Guard/.env\", \"w\") as f:\n",
-    "    f.write(env)\n",
-    "print(\".env written\")"
-   ]
+   "source": "# Write .env\nenv = f\"\"\"\nMEDIC_ENV=kaggle\nMEDIC_QUANTIZATION=4bit\n\n# Agent 1, 2, 4 — MedGemma 4B IT\nMEDIC_LOCAL_MEDGEMMA_4B_MODEL={MEDGEMMA_4B}\n\n# Agent 3 — MedGemma 27B Text IT (subbed with 4B for Kaggle T4)\n# To use full 27B: set to google/medgemma-27b-text-it\nMEDIC_LOCAL_MEDGEMMA_27B_MODEL={MEDGEMMA_4B}\n\n# Agent 4 safety — TxGemma 9B (subbed with 2B for Kaggle T4)\n# To use full 9B: set to google/txgemma-9b-predict\nMEDIC_LOCAL_TXGEMMA_9B_MODEL={TXGEMMA_2B}\nMEDIC_LOCAL_TXGEMMA_2B_MODEL={TXGEMMA_2B}\n\nMEDIC_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2\nMEDIC_DATA_DIR=/kaggle/working/AMR-Guard/data\nMEDIC_CHROMA_DB_DIR=/kaggle/working/AMR-Guard/data/chroma_db\n\"\"\".strip()\n\nwith open(\"/kaggle/working/AMR-Guard/.env\", \"w\") as f:\n    f.write(env)\nprint(\".env written\")"
   },
   {
    "cell_type": "code",
@@ -683,4 +655,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}
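
The cell body changes only by dropping the two Vertex variables; the diff looks larger because `source` was also collapsed from a list of strings into a single string. Both encodings are equivalent under the nbformat spec, which a short sketch (assuming `nbformat` is installed and the repo is checked out) can confirm:

```python
# Sketch: nbformat normalizes both source encodings to one string on read.
import nbformat

nb = nbformat.read("notebooks/kaggle_medic_demo.ipynb", as_version=4)
cell = next(c for c in nb.cells if c.cell_type == "code" and "# Write .env" in c.source)
print(cell.source.splitlines()[0])  # "# Write .env", regardless of on-disk encoding
```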
src/config.py CHANGED
@@ -18,7 +18,6 @@ class Settings(BaseModel):
     All configuration for AMR-Guard, read from environment variables.
 
     Supports three deployment targets via MEDIC_ENV: local, kaggle, production.
-    Backend selection (vertex or local) is controlled by MEDIC_DEFAULT_BACKEND.
     """
 
     environment: Literal["local", "kaggle", "production"] = Field(
@@ -34,10 +33,7 @@ class Settings(BaseModel):
         default_factory=lambda: Path(os.getenv("MEDIC_CHROMA_DB_DIR", "data/chroma_db"))
     )
 
-    default_backend: Literal["vertex", "local"] = Field(
-        default_factory=lambda: os.getenv("MEDIC_DEFAULT_BACKEND", "local")  # type: ignore[arg-type]
-    )
-    # 4-bit quantization via bitsandbytes (local backend only)
+    # 4-bit quantization via bitsandbytes
     quantization: Literal["none", "4bit"] = Field(
         default_factory=lambda: os.getenv("MEDIC_QUANTIZATION", "4bit")  # type: ignore[arg-type]
     )
@@ -45,47 +41,17 @@ class Settings(BaseModel):
         default_factory=lambda: os.getenv("MEDIC_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
     )
 
-    # Vertex AI settings
-    use_vertex: bool = Field(
-        default_factory=lambda: os.getenv("MEDIC_USE_VERTEX", "true").lower() in {"1", "true", "yes"}
-    )
-    vertex_project_id: Optional[str] = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_PROJECT_ID")
-    )
-    vertex_location: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_LOCATION", "us-central1")
-    )
-    vertex_medgemma_4b_model: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_MEDGEMMA_4B_MODEL", "med-gemma-4b-it")
-    )
-    vertex_medgemma_27b_model: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_MEDGEMMA_27B_MODEL", "med-gemma-27b-text-it")
-    )
-    vertex_txgemma_9b_model: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_TXGEMMA_9B_MODEL", "tx-gemma-9b")
-    )
-    vertex_txgemma_2b_model: str = Field(
-        default_factory=lambda: os.getenv("MEDIC_VERTEX_TXGEMMA_2B_MODEL", "tx-gemma-2b")
-    )
-    google_application_credentials: Optional[Path] = Field(
-        default_factory=lambda: (
-            Path(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
-            if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ
-            else None
-        )
-    )
-
-    # Local HuggingFace model paths (used when MEDIC_DEFAULT_BACKEND=local)
-    local_medgemma_4b_model: Optional[str] = Field(
+    # Local HuggingFace model paths
+    medgemma_4b_model: Optional[str] = Field(
         default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_4B_MODEL")
     )
-    local_medgemma_27b_model: Optional[str] = Field(
+    medgemma_27b_model: Optional[str] = Field(
         default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_27B_MODEL")
     )
-    local_txgemma_9b_model: Optional[str] = Field(
+    txgemma_9b_model: Optional[str] = Field(
         default_factory=lambda: os.getenv("MEDIC_LOCAL_TXGEMMA_9B_MODEL")
     )
-    local_txgemma_2b_model: Optional[str] = Field(
+    txgemma_2b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_TXGEMMA_2B_MODEL")
     )
 
@@ -94,4 +60,3 @@ class Settings(BaseModel):
 def get_settings() -> Settings:
     """Return the cached Settings singleton. Import this instead of instantiating Settings directly."""
     return Settings()
-
src/loader.py CHANGED
@@ -7,58 +7,9 @@ from .config import get_settings
 
 logger = logging.getLogger(__name__)
 
-TextBackend = Literal["vertex", "local"]
 TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
 
 
-def _resolve_backend(requested: Optional[TextBackend]) -> TextBackend:
-    settings = get_settings()
-    backend = requested or settings.default_backend  # type: ignore[assignment]
-    if backend == "vertex" and not settings.use_vertex:
-        logger.info("Vertex disabled in settings, falling back to local backend.")
-        return "local"
-    return backend
-
-
-@lru_cache(maxsize=8)
-def _get_vertex_chat_model(model_name: TextModelName):
-    """Load a Vertex AI chat model and return a callable that takes a prompt string."""
-    try:
-        from langchain_google_vertexai import ChatVertexAI
-    except Exception as exc:
-        raise RuntimeError(
-            "langchain-google-vertexai is not available; "
-            "install it or switch MEDIC_DEFAULT_BACKEND=local."
-        ) from exc
-
-    settings = get_settings()
-    if settings.vertex_project_id is None:
-        raise RuntimeError(
-            "MEDIC_VERTEX_PROJECT_ID is not set. "
-            "Set it in your environment or .env to use Vertex AI."
-        )
-
-    model_id_map: Dict[TextModelName, str] = {
-        "medgemma_4b": settings.vertex_medgemma_4b_model,
-        "medgemma_27b": settings.vertex_medgemma_27b_model,
-        "txgemma_9b": settings.vertex_txgemma_9b_model,
-        "txgemma_2b": settings.vertex_txgemma_2b_model,
-    }
-
-    llm = ChatVertexAI(
-        model=model_id_map[model_name],
-        project=settings.vertex_project_id,
-        location=settings.vertex_location,
-        temperature=0.2,
-    )
-
-    def _call(prompt: str, **kwargs: Any) -> str:
-        result = llm.invoke(prompt, **kwargs)
-        return str(getattr(result, "content", result))
-
-    return _call
-
-
 @lru_cache(maxsize=8)
 def _get_local_causal_lm(model_name: TextModelName):
     """Load a local HuggingFace causal LM and return a generation callable."""
@@ -67,17 +18,17 @@ def _get_local_causal_lm(model_name: TextModelName):
 
     settings = get_settings()
     model_path_map: Dict[TextModelName, Optional[str]] = {
-        "medgemma_4b": settings.local_medgemma_4b_model,
-        "medgemma_27b": settings.local_medgemma_27b_model,
-        "txgemma_9b": settings.local_txgemma_9b_model,
-        "txgemma_2b": settings.local_txgemma_2b_model,
+        "medgemma_4b": settings.medgemma_4b_model,
+        "medgemma_27b": settings.medgemma_27b_model,
+        "txgemma_9b": settings.txgemma_9b_model,
+        "txgemma_2b": settings.txgemma_2b_model,
     }
 
     model_path = model_path_map[model_name]
     if not model_path:
         raise RuntimeError(
             f"No local model path configured for {model_name}. "
-            "Set MEDIC_LOCAL_*_MODEL or use the Vertex backend."
+            f"Set MEDIC_LOCAL_*_MODEL in your environment or .env."
         )
 
     load_kwargs: Dict[str, Any] = {"device_map": "auto"}
@@ -108,22 +59,18 @@
 @lru_cache(maxsize=32)
 def get_text_model(
     model_name: TextModelName = "medgemma_4b",
-    backend: Optional[TextBackend] = None,
 ) -> Callable[..., str]:
-    """Return a cached callable for the requested model and backend."""
-    resolved = _resolve_backend(backend)
-    return _get_vertex_chat_model(model_name) if resolved == "vertex" else _get_local_causal_lm(model_name)
+    """Return a cached callable for the requested model."""
+    return _get_local_causal_lm(model_name)
 
 
 def run_inference(
     prompt: str,
     model_name: TextModelName = "medgemma_4b",
-    backend: Optional[TextBackend] = None,
    max_new_tokens: int = 512,
     temperature: float = 0.2,
     **kwargs: Any,
 ) -> str:
     """Run inference with the specified model. This is the primary entry point for agents."""
-    model = get_text_model(model_name=model_name, backend=backend)
+    model = get_text_model(model_name=model_name)
     return model(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
-
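
With the backend parameter removed, the agent-facing call site reduces to a model name plus generation arguments. A sketch against the new `run_inference` signature (the prompt is illustrative only):

```python
# Sketch: the post-refactor entry point, matching the new src/loader.py above.
from src.loader import run_inference

answer = run_inference(
    "Summarize empiric therapy considerations for an uncomplicated UTI.",
    model_name="medgemma_4b",  # one of the four TextModelName literals
    max_new_tokens=256,
    temperature=0.2,
)
print(answer)
```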