""" Model loading and inference utilities """ import torch from transformers import AutoTokenizer, AutoModelForCausalLM from sentence_transformers import SentenceTransformer from typing import List, Dict import logging from . import config logger = logging.getLogger(__name__) import os from huggingface_hub import login token = os.getenv("HF_TOKEN") if token: login(token=token) logger.info("HuggingFace login successful") else: logger.warning("HF_TOKEN not found — model download will fail if MedGemma is gated") import re from transformers import AutoProcessor, AutoModelForImageTextToText from transformers import BitsAndBytesConfig quantization_config = BitsAndBytesConfig(load_in_4bit=True) class MedGemmaGenerator: """Wrapper for MedGemma 1.5 4B model""" def __init__(self): logger.info(f"Loading MedGemma model: {config.MEDGEMMA_MODEL_ID}") self.processor = AutoProcessor.from_pretrained(config.MEDGEMMA_MODEL_ID, token=os.getenv("HF_TOKEN"),) self.model = AutoModelForImageTextToText.from_pretrained( config.MEDGEMMA_MODEL_ID, quantization_config=quantization_config, low_cpu_mem_usage=True, token=os.getenv("HF_TOKEN"), ) self.model = self.model.to("cpu") self.model.eval() logger.info("MedGemma model loaded successfully") def _strip_thinking_block(self, text: str) -> str: """ Remove the thinking/reasoning block that Gemma 3-based models emit. MedGemma 1.5 uses thought... tokens. """ text = re.sub( r"thought[\s\S]*?", "", text, flags=re.IGNORECASE, ) text = re.sub( r"[\s\S]*?", "", text, flags=re.IGNORECASE, ) text = re.sub( r"thought[\s\S]*$", "", text, flags=re.IGNORECASE, ) text = re.sub( r"[\s\S]*$", "", text, flags=re.IGNORECASE, ) return text.strip() def generate(self, prompt: str, max_new_tokens: int = None) -> str: """ Generate text from prompt using MedGemma Args: prompt: Input prompt max_new_tokens: Override default max tokens if provided Returns: Generated text (with thinking block removed) """ gen_config = config.GENERATION_CONFIG.copy() if max_new_tokens: gen_config["max_new_tokens"] = max_new_tokens # Use proper message format for MedGemma 1.5 4B-IT messages = [ { "role": "user", "content": [{"type": "text", "text": prompt}] } ] # Apply chat template properly inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to("cpu") # no dtype cast — float32 stays as-is input_len = inputs["input_ids"].shape[-1] with torch.no_grad(): outputs = self.model.generate( **inputs, **gen_config, pad_token_id=self.processor.tokenizer.pad_token_id if hasattr(self.processor, 'tokenizer') else self.processor.pad_token_id, ) # Extract only the generated portion (after the input) generated_tokens = outputs[0][input_len:] generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True) # Strip the thinking block before returning generated_text = self._strip_thinking_block(generated_text) return generated_text class EmbeddingModel: """Wrapper for PubMedBERT embedding model""" def __init__(self): logger.info(f"Loading embedding model: {config.EMBEDDING_MODEL_ID}") self.model = SentenceTransformer( config.EMBEDDING_MODEL_ID, device=config.EMBEDDING_DEVICE ) logger.info("Embedding model loaded successfully") def encode(self, texts: List[str]) -> List[List[float]]: """ Encode texts to embeddings Args: texts: List of text strings to encode Returns: List of embedding vectors """ embeddings = self.model.encode( texts, convert_to_tensor=False, show_progress_bar=False ) return embeddings.tolist() def encode_single(self, 
# Global model instances (loaded once, lazily).
_medgemma_instance = None
_embedding_instance = None


def get_medgemma() -> MedGemmaGenerator:
    """Get or create the MedGemma model instance."""
    global _medgemma_instance
    if _medgemma_instance is None:
        _medgemma_instance = MedGemmaGenerator()
    return _medgemma_instance


def get_embedding_model() -> EmbeddingModel:
    """Get or create the embedding model instance."""
    global _embedding_instance
    if _embedding_instance is None:
        _embedding_instance = EmbeddingModel()
    return _embedding_instance
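

# Minimal smoke test (an illustrative sketch): exercises both lazy singletons
# end to end. The prompt and query strings are assumed examples; running this
# requires HF_TOKEN access to the gated MedGemma weights and a
# bitsandbytes-compatible device.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    embedder = get_embedding_model()
    vector = embedder.encode_single("chest pain radiating to the left arm")
    print(f"Embedding dimension: {len(vector)}")

    generator = get_medgemma()
    answer = generator.generate("Briefly list common causes of chest pain.")
    print(answer)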