"""
Model loading and inference utilities
"""
import logging
import os
import re
from typing import Dict, List

import torch
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from . import config

logger = logging.getLogger(__name__)

# Authenticate with the HuggingFace Hub at import time so gated checkpoints
# (e.g. MedGemma) can be downloaded. Best-effort: a missing token is only
# warned about here and surfaces later as a download failure.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
    logger.info("HuggingFace login successful")
else:
    logger.warning("HF_TOKEN not found — model download will fail if MedGemma is gated")

# Load model weights in 4-bit (bitsandbytes) to cut the memory footprint.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
class MedGemmaGenerator:
    """Wrapper for the MedGemma 1.5 4B model.

    Loads the processor and the 4-bit quantized model once at construction;
    ``generate`` runs a single-turn, chat-formatted completion and strips the
    model's "thinking" block from the decoded output.
    """

    def __init__(self):
        logger.info(f"Loading MedGemma model: {config.MEDGEMMA_MODEL_ID}")
        self.processor = AutoProcessor.from_pretrained(
            config.MEDGEMMA_MODEL_ID,
            token=os.getenv("HF_TOKEN"),
        )
        self.model = AutoModelForImageTextToText.from_pretrained(
            config.MEDGEMMA_MODEL_ID,
            quantization_config=quantization_config,
            low_cpu_mem_usage=True,
            token=os.getenv("HF_TOKEN"),
        )
        # FIX: the previous `self.model.to("cpu")` is rejected by transformers
        # for bitsandbytes-quantized models (`.to` raises a ValueError on
        # 4-bit/8-bit models); device placement is handled at load time, so
        # no explicit move is performed. NOTE(review): confirm the intended
        # target device — 4-bit bitsandbytes loading generally requires CUDA.
        self.model.eval()
        logger.info("MedGemma model loaded successfully")

    def _strip_thinking_block(self, text: str) -> str:
        """Remove the thinking/reasoning block that Gemma 3-based models emit.

        MedGemma 1.5 wraps its reasoning in ``<unused94>thought ... <unused95>``
        tokens; the ``<think>...</think>`` form is handled as well. Unterminated
        blocks (truncated generations) are stripped to end-of-string.
        """
        patterns = (
            r"<unused94>thought[\s\S]*?<unused95>",  # closed thought block
            r"<think>[\s\S]*?</think>",              # closed <think> block
            r"<unused94>thought[\s\S]*$",            # unterminated thought block
            r"<think>[\s\S]*$",                      # unterminated <think> block
        )
        for pattern in patterns:
            text = re.sub(pattern, "", text, flags=re.IGNORECASE)
        return text.strip()

    def generate(self, prompt: str, max_new_tokens: int = None) -> str:
        """Generate text from a prompt using MedGemma.

        Args:
            prompt: Input prompt text.
            max_new_tokens: Override the configured token budget when given.

        Returns:
            Generated text (with the thinking block removed).
        """
        gen_config = config.GENERATION_CONFIG.copy()
        # FIX: `is not None` instead of a truthiness test, so an explicit
        # override of 0 is no longer silently ignored.
        if max_new_tokens is not None:
            gen_config["max_new_tokens"] = max_new_tokens

        # MedGemma 1.5 4B-IT expects the structured chat-message format.
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to("cpu")  # no dtype cast — float32 stays as-is

        input_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **gen_config,
                pad_token_id=(
                    self.processor.tokenizer.pad_token_id
                    if hasattr(self.processor, 'tokenizer')
                    else self.processor.pad_token_id
                ),
            )

        # Decode only the newly generated portion (tokens after the prompt).
        generated_tokens = outputs[0][input_len:]
        generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
        # Strip the thinking block before returning.
        return self._strip_thinking_block(generated_text)
class EmbeddingModel:
    """Wrapper for PubMedBERT embedding model."""

    def __init__(self):
        logger.info(f"Loading embedding model: {config.EMBEDDING_MODEL_ID}")
        self.model = SentenceTransformer(
            config.EMBEDDING_MODEL_ID,
            device=config.EMBEDDING_DEVICE,
        )
        logger.info("Embedding model loaded successfully")

    def encode(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts.

        Args:
            texts: Strings to embed.

        Returns:
            One embedding vector (list of floats) per input text.
        """
        vectors = self.model.encode(
            texts,
            convert_to_tensor=False,
            show_progress_bar=False,
        )
        return vectors.tolist()

    def encode_single(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""
        (vector,) = self.encode([text])
        return vector
# Global model instances (loaded once)
_medgemma_instance = None
_embedding_instance = None


def get_medgemma() -> MedGemmaGenerator:
    """Return the process-wide MedGemma instance, creating it on first use."""
    global _medgemma_instance
    if _medgemma_instance is not None:
        return _medgemma_instance
    _medgemma_instance = MedGemmaGenerator()
    return _medgemma_instance
def get_embedding_model() -> EmbeddingModel:
    """Return the process-wide embedding model, creating it on first use."""
    global _embedding_instance
    if _embedding_instance is not None:
        return _embedding_instance
    _embedding_instance = EmbeddingModel()
    return _embedding_instance