samwell Claude committed on
Commit b2fc7a6 · 1 Parent(s): 27f1dea

Replace Gemini with MedGemma-4B as main orchestrator


- Create ChatMedGemma LangChain wrapper with multimodal support
- Add MedGemma provider to ModelFactory with 4-bit quantization
- Update app.py to use MedGemma-4B instead of Gemini 2.0 Flash
- Benefits: Medical specialization (88.9% F1 on MIMIC-CXR), privacy, cost savings

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
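
As rough sizing context for the quantization benefit (an estimate, not a measured number): at 4-bit NF4 the ~4B weights occupy about 4e9 × 0.5 bytes ≈ 2 GB, versus roughly 8 GB in bfloat16, before activations and KV cache; that is what makes running the orchestrator on a single local GPU practical.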

app.py CHANGED
@@ -113,9 +113,11 @@ except Exception as e:
 checkpointer = MemorySaver()
 
 llm = ModelFactory.create_model(
-    model_name="gemini-2.0-flash",
-    temperature=0.7,
-    max_tokens=5000
+    model_name="medgemma-4b-it",
+    temperature=1.0,
+    max_tokens=2048,
+    device=device,
+    load_in_4bit=True
 )
 
 prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
@@ -167,7 +169,7 @@ def chat(message, history):
 
 # Custom interface with image output
 with gr.Blocks() as demo:
-    gr.Markdown(f"# MedRAX2 - Medical AI Assistant\n**Device:** {device} | **Tools:** {len(tools)} loaded")
+    gr.Markdown(f"# MedRAX2 - Medical AI Assistant (MedGemma-4B)\n**Device:** {device} | **Tools:** {len(tools)} loaded")
 
     chatbot = gr.Chatbot()
     viz_output = gr.Image(label="Grounding Visualization", visible=True)
medrax/models/medgemma.py ADDED
@@ -0,0 +1,184 @@
+"""MedGemma model wrapper for LangChain compatibility."""
+from typing import Any, List, Optional
+import torch
+from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.callbacks import CallbackManagerForLLMRun
+
+
+class ChatMedGemma(BaseChatModel):
+    """LangChain wrapper for the MedGemma multimodal model."""
+
+    model: Any = None
+    processor: Any = None
+    model_name: str = "google/medgemma-4b-it"
+    device: str = "cuda"
+    max_new_tokens: int = 2048
+    temperature: float = 1.0
+    top_p: float = 0.95
+    top_k: int = 64
+
+    def __init__(
+        self,
+        model_name: str = "google/medgemma-4b-it",
+        device: str = "cuda",
+        load_in_4bit: bool = True,
+        max_new_tokens: int = 2048,
+        temperature: float = 1.0,
+        top_p: float = 0.95,
+        top_k: int = 64,
+        **kwargs
+    ):
+        """Initialize the MedGemma model.
+
+        Args:
+            model_name: Hugging Face model name
+            device: Device to load the model on (cuda/cpu)
+            load_in_4bit: Whether to use 4-bit quantization
+            max_new_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+            top_k: Top-k sampling parameter
+        """
+        super().__init__(**kwargs)
+        self.model_name = model_name
+        self.device = device
+        self.max_new_tokens = max_new_tokens
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+
+        # Set up 4-bit NF4 quantization (CUDA only)
+        if load_in_4bit and device == "cuda":
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+            )
+        else:
+            quantization_config = None
+
+        # Load model and processor
+        print(f"Loading MedGemma model: {model_name}...")
+        self.processor = AutoProcessor.from_pretrained(model_name)
+        self.model = AutoModelForImageTextToText.from_pretrained(
+            model_name,
+            device_map=device,
+            torch_dtype=torch.bfloat16,
+            quantization_config=quantization_config,
+            trust_remote_code=True,
+        ).eval()
+
+        # Enable sampling by default
+        self.model.generation_config.do_sample = True
+        print("✓ MedGemma model loaded successfully")
+
+    def _convert_messages_to_medgemma_format(self, messages: List[BaseMessage]) -> List[dict]:
+        """Convert LangChain messages to MedGemma's chat format."""
+        converted_messages = []
+
+        for message in messages:
+            if isinstance(message, SystemMessage):
+                # MedGemma's chat template accepts a system turn directly
+                converted_messages.append({
+                    "role": "system",
+                    "content": [{"type": "text", "text": message.content}]
+                })
+            elif isinstance(message, HumanMessage):
+                content = []
+
+                # Handle multimodal content (a list of text/image parts)
+                if isinstance(message.content, list):
+                    for item in message.content:
+                        if isinstance(item, dict):
+                            if item.get("type") == "image_url":
+                                # Extract the image path or URL
+                                image_url = item.get("image_url", {})
+                                if isinstance(image_url, dict):
+                                    url = image_url.get("url", "")
+                                else:
+                                    url = image_url
+                                content.append({"type": "image", "url": url})
+                            elif item.get("type") == "text":
+                                content.append({"type": "text", "text": item.get("text", "")})
+                        elif isinstance(item, str):
+                            content.append({"type": "text", "text": item})
+                elif isinstance(message.content, str):
+                    content = [{"type": "text", "text": message.content}]
+
+                converted_messages.append({"role": "user", "content": content})
+
+            elif isinstance(message, AIMessage):
+                converted_messages.append({
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": message.content}]
+                })
+
+        return converted_messages
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Generate a response from MedGemma."""
+        # Convert messages to MedGemma format
+        medgemma_messages = self._convert_messages_to_medgemma_format(messages)
+
+        # Apply the chat template and tokenize
+        inputs = self.processor.apply_chat_template(
+            medgemma_messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(device=self.model.device, dtype=torch.bfloat16)
+
+        # Generate the response
+        with torch.inference_mode():
+            output_ids = self.model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=True,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                pad_token_id=self.processor.tokenizer.eos_token_id,
+            )
+
+        # Decode only the newly generated tokens
+        prompt_length = inputs["input_ids"].shape[-1]
+        generated_ids = output_ids[0][prompt_length:]
+        response_text = self.processor.decode(
+            generated_ids,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True
+        )
+
+        # Create the ChatGeneration
+        message = AIMessage(content=response_text)
+        generation = ChatGeneration(message=message)
+
+        return ChatResult(generations=[generation])
+
+    @property
+    def _llm_type(self) -> str:
+        """Return the type of LLM."""
+        return "medgemma"
+
+    @property
+    def _identifying_params(self) -> dict:
+        """Return identifying parameters."""
+        return {
+            "model_name": self.model_name,
+            "device": self.device,
+            "max_new_tokens": self.max_new_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+        }
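
For reference, a minimal usage sketch of the wrapper above, driven directly rather than through the Gradio app. The image path and prompt are hypothetical, and it assumes a CUDA GPU with enough free VRAM for the 4-bit checkpoint:

from langchain_core.messages import HumanMessage, SystemMessage
from medrax.models.medgemma import ChatMedGemma

# Loads google/medgemma-4b-it with 4-bit NF4 quantization (see __init__ above)
llm = ChatMedGemma(model_name="google/medgemma-4b-it", load_in_4bit=True)

messages = [
    SystemMessage(content="You are a radiology assistant."),
    HumanMessage(content=[
        # Hypothetical local path; converted to {"type": "image", "url": ...}
        {"type": "image_url", "image_url": {"url": "/tmp/chest_xray.png"}},
        {"type": "text", "text": "Describe any abnormalities in this chest X-ray."},
    ]),
]

result = llm.invoke(messages)  # BaseChatModel.invoke returns an AIMessage
print(result.content)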
medrax/models/model_factory.py CHANGED
@@ -7,6 +7,7 @@ from langchain_core.language_models import BaseLanguageModel
 from langchain_openai import ChatOpenAI
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_xai import ChatXAI
+from .medgemma import ChatMedGemma
 
 
 class ModelFactory:
@@ -43,6 +44,11 @@ class ModelFactory:
             "class": ChatXAI,
             "env_key": "XAI_API_KEY",
         },
+        "medgemma": {
+            "class": ChatMedGemma,
+            "env_key": None,  # Local model, no API key needed
+            "is_local": True,
+        },
         # Add more providers with default configurations here
     }
 
@@ -90,16 +96,18 @@ class ModelFactory:
         provider = cls._model_providers[provider_prefix]
         model_class = provider["class"]
         env_key = provider["env_key"]
+        is_local = provider.get("is_local", False)
 
         # Set up provider-specific kwargs
         provider_kwargs = {}
 
-        # Handle API key
-        if env_key in os.environ:
-            provider_kwargs["api_key"] = os.environ[env_key]
-        else:
-            # Log warning but don't fail - the model class might handle missing API keys differently
-            print(f"Warning: Environment variable {env_key} not found. Authentication may fail.")
+        # Handle API key (skip for local models)
+        if not is_local:
+            if env_key and env_key in os.environ:
+                provider_kwargs["api_key"] = os.environ[env_key]
+            elif env_key:
+                # Log warning but don't fail - the model class might handle missing API keys differently
+                print(f"Warning: Environment variable {env_key} not found. Authentication may fail.")
 
         # Check for base_url if applicable
         if "base_url_key" in provider:
@@ -131,6 +139,19 @@
             **kwargs,
         )
 
+        # Handle MedGemma (a local model with different parameter names)
+        if model_name.startswith("medgemma"):
+            return model_class(
+                model_name=actual_model_name,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=kwargs.get("top_k", 64),
+                max_new_tokens=max_tokens,
+                device=kwargs.get("device", "cuda"),
+                load_in_4bit=kwargs.get("load_in_4bit", True),
+                **provider_kwargs,
+            )
+
         # Create and return the model instance
         return model_class(
             model=actual_model_name,
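
And a sketch of the new factory route, assuming create_model's existing signature (temperature, max_tokens, top_p, plus **kwargs) as implied by the hunks above. The "medgemma" prefix selects the local provider, and max_tokens is forwarded as max_new_tokens:

from medrax.models.model_factory import ModelFactory

llm = ModelFactory.create_model(
    model_name="medgemma-4b-it",  # "medgemma" prefix routes to ChatMedGemma
    temperature=1.0,
    max_tokens=2048,        # becomes max_new_tokens for the local model
    device="cuda",          # read via kwargs.get("device", "cuda")
    load_in_4bit=True,      # read via kwargs.get("load_in_4bit", True)
)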