File size: 5,680 Bytes
7529164
 
 
 
 
 
 
 
 
 
 
 
 
cc92cee
 
 
1b65fec
cc92cee
 
ebaa9cf
 
 
cc92cee
7529164
 
 
1fdf84c
 
 
 
7529164
 
 
 
 
a6956fd
ebaa9cf
7529164
 
1fdf84c
a6956fd
ebaa9cf
7529164
a6956fd
7529164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6956fd
7529164
 
 
 
 
 
a6956fd
7529164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Model loading and inference utilities
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from typing import List, Dict
import logging

from . import config

logger = logging.getLogger(__name__)

import os
from huggingface_hub import login

# Authenticate against the Hugging Face Hub at import time: MedGemma is a
# gated model, so the downloads in MedGemmaGenerator.__init__ need a token.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
    logger.info("HuggingFace login successful")
else:
    logger.warning("HF_TOKEN not found — model download will fail if MedGemma is gated")

import re
from transformers import AutoProcessor, AutoModelForImageTextToText

from transformers import BitsAndBytesConfig

# Load model weights in 4-bit via bitsandbytes to reduce the memory footprint.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

class MedGemmaGenerator:
    """Wrapper for the MedGemma 1.5 4B image-text-to-text model.

    Loads the processor and a 4-bit-quantized model once in __init__ and
    exposes plain-text generation via generate().
    """

    def __init__(self):
        logger.info(f"Loading MedGemma model: {config.MEDGEMMA_MODEL_ID}")

        # Read the gated-repo token once instead of hitting the env twice.
        hf_token = os.getenv("HF_TOKEN")
        self.processor = AutoProcessor.from_pretrained(
            config.MEDGEMMA_MODEL_ID,
            token=hf_token,
        )
        self.model = AutoModelForImageTextToText.from_pretrained(
            config.MEDGEMMA_MODEL_ID,
            quantization_config=quantization_config,
            low_cpu_mem_usage=True,
            token=hf_token,
        )
        # BUGFIX: the original called `self.model.to("cpu")` here, but
        # transformers raises a ValueError when `.to()` is used on a
        # bitsandbytes-quantized model — device placement is fixed at load
        # time by the quantization backend, so the cast is removed.
        self.model.eval()
        logger.info("MedGemma model loaded successfully")

    def _strip_thinking_block(self, text: str) -> str:
        """Remove the thinking/reasoning block that Gemma 3-based models emit.

        MedGemma 1.5 wraps its reasoning in <unused94>thought ... <unused95>
        tokens; a generic <think>...</think> form is handled too. Unclosed
        blocks (generation cut off mid-thought) are stripped to end-of-text.

        Args:
            text: Raw decoded model output.

        Returns:
            Text with any thinking block removed, whitespace-stripped.
        """
        # Closed blocks first, then unterminated ones that run to the end.
        patterns = (
            r"<unused94>thought[\s\S]*?<unused95>",
            r"<think>[\s\S]*?</think>",
            r"<unused94>thought[\s\S]*$",
            r"<think>[\s\S]*$",
        )
        for pattern in patterns:
            text = re.sub(pattern, "", text, flags=re.IGNORECASE)
        return text.strip()

    def generate(self, prompt: str, max_new_tokens: int = None) -> str:
        """Generate text from prompt using MedGemma.

        Args:
            prompt: Input prompt.
            max_new_tokens: Override default max tokens if provided.

        Returns:
            Generated text (with thinking block removed).
        """
        gen_config = config.GENERATION_CONFIG.copy()
        # BUGFIX: compare against None so an explicit 0 is not silently
        # treated the same as "use the default".
        if max_new_tokens is not None:
            gen_config["max_new_tokens"] = max_new_tokens

        # MedGemma 1.5 4B-IT expects the chat message structure.
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]

        # Apply chat template properly
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to("cpu")          # no dtype cast — float32 stays as-is

        input_len = inputs["input_ids"].shape[-1]

        # Multimodal processors normally carry a `tokenizer`; fall back to
        # the processor itself otherwise (same effect as the original inline
        # hasattr conditional, computed once for readability).
        tokenizer = getattr(self.processor, "tokenizer", self.processor)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **gen_config,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Extract only the generated portion (after the input).
        generated_tokens = outputs[0][input_len:]
        generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)

        # Strip the thinking block before returning.
        return self._strip_thinking_block(generated_text)

class EmbeddingModel:
    """Wrapper around the PubMedBERT sentence-embedding model."""

    def __init__(self):
        logger.info(f"Loading embedding model: {config.EMBEDDING_MODEL_ID}")
        self.model = SentenceTransformer(
            config.EMBEDDING_MODEL_ID,
            device=config.EMBEDDING_DEVICE
        )
        logger.info("Embedding model loaded successfully")

    def encode(self, texts: List[str]) -> List[List[float]]:
        """Embed every string in *texts*.

        Args:
            texts: Strings to embed.

        Returns:
            One embedding vector (list of floats) per input string.
        """
        vectors = self.model.encode(
            texts,
            convert_to_tensor=False,
            show_progress_bar=False
        )
        return vectors.tolist()

    def encode_single(self, text: str) -> List[float]:
        """Embed one string.

        Args:
            text: String to embed.

        Returns:
            The embedding vector for *text*.
        """
        (vector,) = self.encode([text])
        return vector


# Global model instances (loaded once)
_medgemma_instance = None
_embedding_instance = None


def get_medgemma() -> MedGemmaGenerator:
    """Return the process-wide MedGemmaGenerator, building it on first call."""
    global _medgemma_instance
    if _medgemma_instance is not None:
        return _medgemma_instance
    _medgemma_instance = MedGemmaGenerator()
    return _medgemma_instance


def get_embedding_model() -> EmbeddingModel:
    """Return the process-wide EmbeddingModel, building it on first call."""
    global _embedding_instance
    if _embedding_instance is not None:
        return _embedding_instance
    _embedding_instance = EmbeddingModel()
    return _embedding_instance