ghitaben committed on
Commit
18c0556
·
1 Parent(s): 4e19176

add logs

Browse files
Files changed (3) hide show
  1. notebooks/kaggle_medic_demo.ipynb +104 -16
  2. src/graph.py +14 -4
  3. src/loader.py +89 -9
notebooks/kaggle_medic_demo.ipynb CHANGED
@@ -12,16 +12,16 @@
12
  "|---|---|---|\n",
13
  "| 1 · Intake Historian | Patient data, CrCl, MDR risk | MedGemma 4B IT |\n",
14
  "| 2 · Vision Specialist | Lab report → structured JSON | MedGemma 4B IT |\n",
15
- "| 3 · Trend Analyst | MIC creep, resistance velocity | MedGemma 27B Text IT ¹ |\n",
16
- "| 4 · Clinical Pharmacologist | Final Rx + safety check | MedGemma 4B IT + TxGemma 9B ¹ |\n",
17
- "\n",
18
- "> ¹ Substituted with smaller variants on Kaggle T4 (16 GB GPU) — see Section 3.\n",
19
  "\n",
20
  "**Before running this notebook:**\n",
21
  "1. Click **Add data** (top-right) → search for **`mghobashy/drug-drug-interactions`** → add it\n",
22
  "2. Add your HuggingFace token under **Add-ons → Secrets** as `HF_TOKEN`\n",
23
  "3. Accept model licences on HuggingFace (see Section 2)\n",
24
  "\n",
 
 
25
  "**Steps:** Clone → Install → Authenticate → Download models → Init KB → Launch app"
26
  ]
27
  },
@@ -69,7 +69,12 @@
69
  "id": "4c637bc0",
70
  "metadata": {},
71
  "outputs": [],
72
- "source": "%%bash\n# Always start fresh to avoid stale code from previous runs\nrm -rf /kaggle/working/AMR-Guard\ngit clone \"$GITHUB_REPO\" /kaggle/working/AMR-Guard"
 
 
 
 
 
73
  },
74
  {
75
  "cell_type": "code",
@@ -98,7 +103,8 @@
98
  "\n",
99
  "Accept the model licences **before** running this notebook:\n",
100
  "- MedGemma 4B IT → https://huggingface.co/google/medgemma-4b-it\n",
101
- "- TxGemma 2B → https://huggingface.co/google/txgemma-2b-predict"
 
102
  ]
103
  },
104
  {
@@ -130,15 +136,13 @@
130
  "source": [
131
  "## 3 · Download Models\n",
132
  "\n",
133
- "| Model | Agent | VRAM (4-bit) | Kaggle T4 |\n",
134
- "|---|---|---|---|\n",
135
- "| `google/medgemma-4b-it` | 1, 2, 4 primary | ~3 GB | ✓ |\n",
136
- "| `google/medgemma-27b-text-it` | 3 (Trend Analyst) | ~14 GB | marginal — using 4B sub |\n",
137
- "| `google/txgemma-9b-predict` | 4 safety check | ~5 GB | ✓ (stacked with 4B: ~8 GB) |\n",
138
- "| `google/txgemma-2b-predict` | 4 safety fallback | ~1.5 GB | ✓ |\n",
139
  "\n",
140
- "**Kaggle strategy:** download `medgemma-4b-it` and `txgemma-2b-predict`. \n",
141
- "Swap to the full 27B / 9B on a machine with ≥ 24 GB VRAM by editing the variables below."
142
  ]
143
  },
144
  {
@@ -224,7 +228,33 @@
224
  "id": "a61f1fb1",
225
  "metadata": {},
226
  "outputs": [],
227
- "source": "# Write .env\nenv = f\"\"\"\nMEDIC_ENV=kaggle\nMEDIC_QUANTIZATION=4bit\n\n# Agent 1, 2, 4 — MedGemma 4B IT\nMEDIC_LOCAL_MEDGEMMA_4B_MODEL={MEDGEMMA_4B}\n\n# Agent 3 — MedGemma 27B Text IT (subbed with 4B for Kaggle T4)\n# To use full 27B: set to google/medgemma-27b-text-it\nMEDIC_LOCAL_MEDGEMMA_27B_MODEL={MEDGEMMA_4B}\n\n# Agent 4 safety — TxGemma 9B (subbed with 2B for Kaggle T4)\n# To use full 9B: set to google/txgemma-9b-predict\nMEDIC_LOCAL_TXGEMMA_9B_MODEL={TXGEMMA_2B}\nMEDIC_LOCAL_TXGEMMA_2B_MODEL={TXGEMMA_2B}\n\nMEDIC_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2\nMEDIC_DATA_DIR=/kaggle/working/AMR-Guard/data\nMEDIC_CHROMA_DB_DIR=/kaggle/working/AMR-Guard/data/chroma_db\n\"\"\".strip()\n\nwith open(\"/kaggle/working/AMR-Guard/.env\", \"w\") as f:\n f.write(env)\nprint(\".env written\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  },
229
  {
230
  "cell_type": "code",
@@ -467,6 +497,7 @@
467
  " \"\"\"Stage 1 — Empirical: no lab results.\n",
468
  " Active models: MedGemma 4B (Agent 1) → MedGemma 4B + TxGemma 2B (Agent 4).\n",
469
  " \"\"\"\n",
 
470
  " patient_data = _build_patient_data(\n",
471
  " age, weight, height, sex, creatinine,\n",
472
  " infection_site, suspected_source,\n",
@@ -474,8 +505,11 @@
474
  " )\n",
475
  " try:\n",
476
  " from src.graph import run_pipeline\n",
 
477
  " result = run_pipeline(patient_data, labs_raw_text=None)\n",
 
478
  " except Exception as exc:\n",
 
479
  " result = _demo_result(patient_data, None)\n",
480
  " result[\"errors\"].append(f\"[Demo mode — pipeline error: {exc}]\")\n",
481
  " return format_recommendation(result)\n",
@@ -489,6 +523,7 @@
489
  " Active models: MedGemma 4B (Agents 1, 2) → MedGemma 4B→27B sub (Agent 3)\n",
490
  " → MedGemma 4B + TxGemma 2B (Agent 4).\n",
491
  " \"\"\"\n",
 
492
  " patient_data = _build_patient_data(\n",
493
  " age, weight, height, sex, creatinine,\n",
494
  " infection_site, suspected_source,\n",
@@ -497,8 +532,11 @@
497
  " labs = labs_text.strip() if labs_text else None\n",
498
  " try:\n",
499
  " from src.graph import run_pipeline\n",
 
500
  " result = run_pipeline(patient_data, labs_raw_text=labs)\n",
 
501
  " except Exception as exc:\n",
 
502
  " result = _demo_result(patient_data, labs)\n",
503
  " result[\"errors\"].append(f\"[Demo mode — pipeline error: {exc}]\")\n",
504
  " return format_recommendation(result), format_lab_analysis(result)\n",
@@ -507,6 +545,56 @@
507
  "print(\"Helper functions loaded.\")"
508
  ]
509
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  {
511
  "cell_type": "code",
512
  "execution_count": null,
@@ -647,4 +735,4 @@
647
  },
648
  "nbformat": 4,
649
  "nbformat_minor": 5
650
- }
 
12
  "|---|---|---|\n",
13
  "| 1 · Intake Historian | Patient data, CrCl, MDR risk | MedGemma 4B IT |\n",
14
  "| 2 · Vision Specialist | Lab report → structured JSON | MedGemma 4B IT |\n",
15
+ "| 3 · Trend Analyst | MIC creep, resistance velocity | MedGemma 27B Text IT |\n",
16
+ "| 4 · Clinical Pharmacologist | Final Rx + safety check | MedGemma 4B IT + TxGemma 9B |\n",
 
 
17
  "\n",
18
  "**Before running this notebook:**\n",
19
  "1. Click **Add data** (top-right) → search for **`mghobashy/drug-drug-interactions`** → add it\n",
20
  "2. Add your HuggingFace token under **Add-ons → Secrets** as `HF_TOKEN`\n",
21
  "3. Accept model licences on HuggingFace (see Section 2)\n",
22
  "\n",
23
+ "**Requirements:** GPU with ≥ 30 GB VRAM (e.g. Kaggle P100 / A100 / T4 ×2)\n",
24
+ "\n",
25
  "**Steps:** Clone → Install → Authenticate → Download models → Init KB → Launch app"
26
  ]
27
  },
 
69
  "id": "4c637bc0",
70
  "metadata": {},
71
  "outputs": [],
72
+ "source": [
73
+ "%%bash\n",
74
+ "# Always start fresh to avoid stale code from previous runs\n",
75
+ "rm -rf /kaggle/working/AMR-Guard\n",
76
+ "git clone \"$GITHUB_REPO\" /kaggle/working/AMR-Guard"
77
+ ]
78
  },
79
  {
80
  "cell_type": "code",
 
103
  "\n",
104
  "Accept the model licences **before** running this notebook:\n",
105
  "- MedGemma 4B IT → https://huggingface.co/google/medgemma-4b-it\n",
106
+ "- MedGemma 27B Text IT → https://huggingface.co/google/medgemma-27b-text-it\n",
107
+ "- TxGemma 9B → https://huggingface.co/google/txgemma-9b-predict"
108
  ]
109
  },
110
  {
 
136
  "source": [
137
  "## 3 · Download Models\n",
138
  "\n",
139
+ "| Model | Agent | VRAM (4-bit) |\n",
140
+ "|---|---|---|\n",
141
+ "| `google/medgemma-4b-it` | 1, 2, 4 primary | ~3 GB |\n",
142
+ "| `google/medgemma-27b-text-it` | 3 (Trend Analyst) | ~14 GB |\n",
143
+ "| `google/txgemma-9b-predict` | 4 safety check | ~5 GB |\n",
 
144
  "\n",
145
+ "**Total estimated VRAM:** ~22 GB in 4-bit quantization."
 
146
  ]
147
  },
148
  {
 
228
  "id": "a61f1fb1",
229
  "metadata": {},
230
  "outputs": [],
231
+ "source": [
232
+ "# Write .env\n",
233
+ "env = f\"\"\"\n",
234
+ "MEDIC_ENV=kaggle\n",
235
+ "MEDIC_QUANTIZATION=4bit\n",
236
+ "\n",
237
+ "# Agent 1, 2, 4 — MedGemma 4B IT\n",
238
+ "MEDIC_LOCAL_MEDGEMMA_4B_MODEL={MEDGEMMA_4B}\n",
239
+ "\n",
240
+ "# Agent 3 — MedGemma 27B Text IT (subbed with 4B for Kaggle T4)\n",
241
+ "# To use full 27B: set to google/medgemma-27b-text-it\n",
242
+ "MEDIC_LOCAL_MEDGEMMA_27B_MODEL={MEDGEMMA_4B}\n",
243
+ "\n",
244
+ "# Agent 4 safety — TxGemma 9B (subbed with 2B for Kaggle T4)\n",
245
+ "# To use full 9B: set to google/txgemma-9b-predict\n",
246
+ "MEDIC_LOCAL_TXGEMMA_9B_MODEL={TXGEMMA_2B}\n",
247
+ "MEDIC_LOCAL_TXGEMMA_2B_MODEL={TXGEMMA_2B}\n",
248
+ "\n",
249
+ "MEDIC_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2\n",
250
+ "MEDIC_DATA_DIR=/kaggle/working/AMR-Guard/data\n",
251
+ "MEDIC_CHROMA_DB_DIR=/kaggle/working/AMR-Guard/data/chroma_db\n",
252
+ "\"\"\".strip()\n",
253
+ "\n",
254
+ "with open(\"/kaggle/working/AMR-Guard/.env\", \"w\") as f:\n",
255
+ " f.write(env)\n",
256
+ "print(\".env written\")"
257
+ ]
258
  },
259
  {
260
  "cell_type": "code",
 
497
  " \"\"\"Stage 1 — Empirical: no lab results.\n",
498
  " Active models: MedGemma 4B (Agent 1) → MedGemma 4B + TxGemma 2B (Agent 4).\n",
499
  " \"\"\"\n",
500
+ " logger.info(f\"Starting empirical scenario: {infection_site} infection\")\n",
501
  " patient_data = _build_patient_data(\n",
502
  " age, weight, height, sex, creatinine,\n",
503
  " infection_site, suspected_source,\n",
 
505
  " )\n",
506
  " try:\n",
507
  " from src.graph import run_pipeline\n",
508
+ " logger.info(\"Calling run_pipeline...\")\n",
509
  " result = run_pipeline(patient_data, labs_raw_text=None)\n",
510
+ " logger.info(\"Pipeline completed successfully\")\n",
511
  " except Exception as exc:\n",
512
+ " logger.error(f\"Pipeline failed: {exc}\", exc_info=True)\n",
513
  " result = _demo_result(patient_data, None)\n",
514
  " result[\"errors\"].append(f\"[Demo mode — pipeline error: {exc}]\")\n",
515
  " return format_recommendation(result)\n",
 
523
  " Active models: MedGemma 4B (Agents 1, 2) → MedGemma 4B→27B sub (Agent 3)\n",
524
  " → MedGemma 4B + TxGemma 2B (Agent 4).\n",
525
  " \"\"\"\n",
526
+ " logger.info(f\"Starting targeted scenario: {infection_site} infection with lab data\")\n",
527
  " patient_data = _build_patient_data(\n",
528
  " age, weight, height, sex, creatinine,\n",
529
  " infection_site, suspected_source,\n",
 
532
  " labs = labs_text.strip() if labs_text else None\n",
533
  " try:\n",
534
  " from src.graph import run_pipeline\n",
535
+ " logger.info(\"Calling run_pipeline with lab data...\")\n",
536
  " result = run_pipeline(patient_data, labs_raw_text=labs)\n",
537
+ " logger.info(\"Pipeline completed successfully\")\n",
538
  " except Exception as exc:\n",
539
+ " logger.error(f\"Pipeline failed: {exc}\", exc_info=True)\n",
540
  " result = _demo_result(patient_data, labs)\n",
541
  " result[\"errors\"].append(f\"[Demo mode — pipeline error: {exc}]\")\n",
542
  " return format_recommendation(result), format_lab_analysis(result)\n",
 
545
  "print(\"Helper functions loaded.\")"
546
  ]
547
  },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": null,
551
+ "id": "57203904",
552
+ "metadata": {},
553
+ "outputs": [],
554
+ "source": [
555
+ "# Enable logging to see what's happening during pipeline execution\n",
556
+ "import logging\n",
557
+ "logging.basicConfig(\n",
558
+ " level=logging.INFO,\n",
559
+ " format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'\n",
560
+ ")\n",
561
+ "logger = logging.getLogger(__name__)\n",
562
+ "logger.info(\"Logging enabled for pipeline debugging\")"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "execution_count": null,
568
+ "id": "69c8a263",
569
+ "metadata": {},
570
+ "outputs": [],
571
+ "source": [
572
+ "# Test model configuration and availability\n",
573
+ "print(\"Testing model configuration...\")\n",
574
+ "from src.config import get_settings\n",
575
+ "from src.loader import _get_model_path, _is_multimodal\n",
576
+ "\n",
577
+ "settings = get_settings()\n",
578
+ "print(f\"Environment: {settings.environment}\")\n",
579
+ "print(f\"Quantization: {settings.quantization}\")\n",
580
+ "print(f\"\\nConfigured models:\")\n",
581
+ "print(f\" MedGemma 4B: {settings.medgemma_4b_model}\")\n",
582
+ "print(f\" MedGemma 27B: {settings.medgemma_27b_model}\")\n",
583
+ "print(f\" TxGemma 2B: {settings.txgemma_2b_model}\")\n",
584
+ "print(f\" TxGemma 9B: {settings.txgemma_9b_model}\")\n",
585
+ "\n",
586
+ "print(f\"\\nModel architectures:\")\n",
587
+ "for model_name in [\"medgemma_4b\", \"txgemma_2b\"]:\n",
588
+ " try:\n",
589
+ " path = _get_model_path(model_name)\n",
590
+ " is_mm = _is_multimodal(path)\n",
591
+ " print(f\" {model_name}: {'multimodal' if is_mm else 'causal LM'}\")\n",
592
+ " except Exception as e:\n",
593
+ " print(f\" {model_name}: ERROR - {e}\")\n",
594
+ "\n",
595
+ "print(\"\\n✓ Configuration validated\")"
596
+ ]
597
+ },
598
  {
599
  "cell_type": "code",
600
  "execution_count": null,
 
735
  },
736
  "nbformat": 4,
737
  "nbformat_minor": 5
738
+ }
src/graph.py CHANGED
@@ -91,16 +91,26 @@ def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> Infect
91
  "country_or_region": patient_data.get("country_or_region"),
92
  "vitals": patient_data.get("vitals", {}),
93
  "stage": "targeted" if labs_raw_text else "empirical",
 
 
94
  }
95
 
96
  if labs_raw_text:
97
  initial_state["labs_raw_text"] = labs_raw_text
98
 
99
  logger.info(f"Starting pipeline (stage: {initial_state['stage']})")
100
- compiled = build_infection_graph().compile()
101
- final_state = compiled.invoke(initial_state)
102
- logger.info("Pipeline complete")
103
- return final_state
 
 
 
 
 
 
 
 
104
 
105
 
106
  def run_empirical_pipeline(patient_data: dict) -> InfectionState:
 
91
  "country_or_region": patient_data.get("country_or_region"),
92
  "vitals": patient_data.get("vitals", {}),
93
  "stage": "targeted" if labs_raw_text else "empirical",
94
+ "errors": [],
95
+ "safety_warnings": [],
96
  }
97
 
98
  if labs_raw_text:
99
  initial_state["labs_raw_text"] = labs_raw_text
100
 
101
  logger.info(f"Starting pipeline (stage: {initial_state['stage']})")
102
+ logger.info(f"Patient: {patient_data.get('age_years')}y, {patient_data.get('sex')}, infection: {patient_data.get('infection_site')}")
103
+
104
+ try:
105
+ compiled = build_infection_graph().compile()
106
+ logger.info("Graph compiled successfully")
107
+ final_state = compiled.invoke(initial_state)
108
+ logger.info("Pipeline complete")
109
+ return final_state
110
+ except Exception as e:
111
+ logger.error(f"Pipeline execution failed: {e}", exc_info=True)
112
+ initial_state["errors"].append(f"Pipeline error: {str(e)}")
113
+ return initial_state
114
 
115
 
116
  def run_empirical_pipeline(patient_data: dict) -> InfectionState:
src/loader.py CHANGED
@@ -9,13 +9,15 @@ logger = logging.getLogger(__name__)
9
 
10
  TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
11
 
 
 
 
 
 
 
12
 
13
- @lru_cache(maxsize=8)
14
- def _get_local_causal_lm(model_name: TextModelName):
15
- """Load a local HuggingFace causal LM and return a generation callable."""
16
- from transformers import AutoModelForCausalLM, AutoTokenizer
17
- import torch
18
 
 
19
  settings = get_settings()
20
  model_path_map: Dict[TextModelName, Optional[str]] = {
21
  "medgemma_4b": settings.medgemma_4b_model,
@@ -23,21 +25,77 @@ def _get_local_causal_lm(model_name: TextModelName):
23
  "txgemma_9b": settings.txgemma_9b_model,
24
  "txgemma_2b": settings.txgemma_2b_model,
25
  }
26
-
27
  model_path = model_path_map[model_name]
28
  if not model_path:
29
  raise RuntimeError(
30
  f"No local model path configured for {model_name}. "
31
  f"Set MEDIC_LOCAL_*_MODEL in your environment or .env."
32
  )
 
33
 
 
 
 
34
  load_kwargs: Dict[str, Any] = {"device_map": "auto"}
35
  if settings.quantization == "4bit":
36
  from transformers import BitsAndBytesConfig
37
  load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
 
 
 
 
 
 
 
 
 
 
 
39
  tokenizer = AutoTokenizer.from_pretrained(model_path)
 
40
  model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
 
41
 
42
  def _call(prompt: str, max_new_tokens: int = 512, temperature: float = 0.2, **generate_kwargs: Any) -> str:
43
  inputs = {k: v.to(model.device) for k, v in tokenizer(prompt, return_tensors="pt").items()}
@@ -46,7 +104,7 @@ def _get_local_causal_lm(model_name: TextModelName):
46
  output_ids = model.generate(
47
  **inputs,
48
  do_sample=do_sample,
49
- temperature=temperature if do_sample else 0.0,
50
  max_new_tokens=max_new_tokens,
51
  **generate_kwargs,
52
  )
@@ -57,11 +115,25 @@ def _get_local_causal_lm(model_name: TextModelName):
57
  return _call
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
60
  @lru_cache(maxsize=32)
61
  def get_text_model(
62
  model_name: TextModelName = "medgemma_4b",
63
  ) -> Callable[..., str]:
64
  """Return a cached callable for the requested model."""
 
 
 
65
  return _get_local_causal_lm(model_name)
66
 
67
 
@@ -73,5 +145,13 @@ def run_inference(
73
  **kwargs: Any,
74
  ) -> str:
75
  """Run inference with the specified model. This is the primary entry point for agents."""
76
- model = get_text_model(model_name=model_name)
77
- return model(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
 
 
 
 
 
 
 
 
 
9
 
10
  TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]
11
 
12
+ # MedGemma 4B IT is a vision-language model (Gemma3ForConditionalGeneration).
13
+ # It must be loaded with AutoModelForImageTextToText + AutoProcessor.
14
+ # All other models (medgemma-27b-text-it, txgemma-*) are causal LMs.
15
+ # On Kaggle T4, medgemma_27b is substituted with medgemma-4b-it (also multimodal),
16
+ # so we detect the architecture dynamically from the model config.
17
+ _MULTIMODAL_ARCHITECTURES = {"Gemma3ForConditionalGeneration"}
18
 
 
 
 
 
 
19
 
20
+ def _get_model_path(model_name: TextModelName) -> str:
21
  settings = get_settings()
22
  model_path_map: Dict[TextModelName, Optional[str]] = {
23
  "medgemma_4b": settings.medgemma_4b_model,
 
25
  "txgemma_9b": settings.txgemma_9b_model,
26
  "txgemma_2b": settings.txgemma_2b_model,
27
  }
 
28
  model_path = model_path_map[model_name]
29
  if not model_path:
30
  raise RuntimeError(
31
  f"No local model path configured for {model_name}. "
32
  f"Set MEDIC_LOCAL_*_MODEL in your environment or .env."
33
  )
34
+ return model_path
35
 
36
+
37
+ def _get_load_kwargs() -> Dict[str, Any]:
38
+ settings = get_settings()
39
  load_kwargs: Dict[str, Any] = {"device_map": "auto"}
40
  if settings.quantization == "4bit":
41
  from transformers import BitsAndBytesConfig
42
  load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
43
+ return load_kwargs
44
+
45
+
46
+ @lru_cache(maxsize=8)
47
+ def _get_local_multimodal(model_name: TextModelName):
48
+ """Load a multimodal model (e.g. MedGemma 4B IT) and return a text generation callable."""
49
+ from transformers import AutoModelForImageTextToText, AutoProcessor
50
+ import torch
51
+
52
+ model_path = _get_model_path(model_name)
53
+ load_kwargs = _get_load_kwargs()
54
+
55
+ logger.info(f"Loading multimodal model: {model_path} with kwargs: {load_kwargs}")
56
+ processor = AutoProcessor.from_pretrained(model_path)
57
+ logger.info(f"Processor loaded for {model_path}")
58
+ model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)
59
+ logger.info(f"Model loaded successfully: {model_path}")
60
+
61
+ def _call(prompt: str, max_new_tokens: int = 512, temperature: float = 0.2, **generate_kwargs: Any) -> str:
62
+ # Build a chat-style input for text-only queries
63
+ messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
64
+ inputs = processor.apply_chat_template(
65
+ messages, add_generation_prompt=True, tokenize=True,
66
+ return_dict=True, return_tensors="pt",
67
+ ).to(model.device)
68
+
69
+ do_sample = temperature > 0
70
+ with torch.no_grad():
71
+ output_ids = model.generate(
72
+ **inputs,
73
+ do_sample=do_sample,
74
+ temperature=temperature if do_sample else None,
75
+ max_new_tokens=max_new_tokens,
76
+ **generate_kwargs,
77
+ )
78
+ # Decode only the newly generated tokens
79
+ generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
80
+ return processor.decode(generated_ids, skip_special_tokens=True).strip()
81
+
82
+ return _call
83
 
84
+
85
+ @lru_cache(maxsize=8)
86
+ def _get_local_causal_lm(model_name: TextModelName):
87
+ """Load a causal LM (e.g. TxGemma, MedGemma 27B text) and return a generation callable."""
88
+ from transformers import AutoModelForCausalLM, AutoTokenizer
89
+ import torch
90
+
91
+ model_path = _get_model_path(model_name)
92
+ load_kwargs = _get_load_kwargs()
93
+
94
+ logger.info(f"Loading causal LM: {model_path} with kwargs: {load_kwargs}")
95
  tokenizer = AutoTokenizer.from_pretrained(model_path)
96
+ logger.info(f"Tokenizer loaded for {model_path}")
97
  model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
98
+ logger.info(f"Model loaded successfully: {model_path}")
99
 
100
  def _call(prompt: str, max_new_tokens: int = 512, temperature: float = 0.2, **generate_kwargs: Any) -> str:
101
  inputs = {k: v.to(model.device) for k, v in tokenizer(prompt, return_tensors="pt").items()}
 
104
  output_ids = model.generate(
105
  **inputs,
106
  do_sample=do_sample,
107
+ temperature=temperature if do_sample else None,
108
  max_new_tokens=max_new_tokens,
109
  **generate_kwargs,
110
  )
 
115
  return _call
116
 
117
 
118
+ def _is_multimodal(model_path: str) -> bool:
119
+ """Check if a model uses a multimodal architecture by inspecting its config."""
120
+ from transformers import AutoConfig
121
+ try:
122
+ config = AutoConfig.from_pretrained(model_path)
123
+ architectures = getattr(config, "architectures", []) or []
124
+ return bool(set(architectures) & _MULTIMODAL_ARCHITECTURES)
125
+ except Exception:
126
+ return False
127
+
128
+
129
  @lru_cache(maxsize=32)
130
  def get_text_model(
131
  model_name: TextModelName = "medgemma_4b",
132
  ) -> Callable[..., str]:
133
  """Return a cached callable for the requested model."""
134
+ model_path = _get_model_path(model_name)
135
+ if _is_multimodal(model_path):
136
+ return _get_local_multimodal(model_name)
137
  return _get_local_causal_lm(model_name)
138
 
139
 
 
145
  **kwargs: Any,
146
  ) -> str:
147
  """Run inference with the specified model. This is the primary entry point for agents."""
148
+ logger.info(f"Running inference with {model_name}, max_tokens={max_new_tokens}, temp={temperature}")
149
+ try:
150
+ model = get_text_model(model_name=model_name)
151
+ logger.info(f"Model {model_name} loaded successfully")
152
+ result = model(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
153
+ logger.info(f"Inference complete, response length: {len(result)} chars")
154
+ return result
155
+ except Exception as e:
156
+ logger.error(f"Inference failed for {model_name}: {e}", exc_info=True)
157
+ raise