File size: 9,548 Bytes
08a811f
 
de81512
08a811f
ad9e267
08a811f
 
 
674163f
 
 
 
 
de81512
 
 
9004314
674163f
9004314
674163f
de81512
674163f
de81512
08a811f
 
 
 
18c0556
 
 
 
 
 
08a811f
 
18c0556
08a811f
 
936bc6b
 
 
 
08a811f
 
 
 
 
936bc6b
08a811f
18c0556
08a811f
18c0556
 
9004314
18c0556
9004314
ad9e267
9004314
4e19176
 
9004314
 
18c0556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba2715f
 
 
 
 
 
 
 
 
 
 
 
 
 
18c0556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08a811f
18c0556
 
 
 
 
 
 
 
 
 
 
08a811f
18c0556
08a811f
18c0556
08a811f
ad9e267
 
08a811f
 
 
 
 
18c0556
08a811f
 
 
ad9e267
 
 
08a811f
 
 
 
9004314
18c0556
 
 
 
 
 
 
 
 
 
 
08a811f
 
 
 
936bc6b
18c0556
 
 
936bc6b
08a811f
 
de5f46b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08a811f
 
 
 
 
 
 
674163f
 
 
 
 
 
18c0556
674163f
ba2715f
 
 
 
 
 
 
 
 
 
674163f
ba2715f
674163f
 
ba2715f
 
674163f
de5f46b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246

import logging
import os
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional

from .config import get_settings

# ── ZeroGPU registration (HF Spaces only) ────────────────────────────────────
# Calling spaces.GPU() at import time establishes the connection with the
# ZeroGPU daemon.  The actual @spaces.GPU decorator for the pipeline lives in
# app.py and wraps the *entire* multi-agent run so all agents share one GPU
# session (avoids lru_cache + freed-CUDA-memory hangs between agent calls).
if os.environ.get("SPACE_ID"):  # SPACE_ID is only present inside a HF Space runtime
    try:
        import spaces as _spaces
        try:
            _spaces.GPU(duration=600)  # register with ZeroGPU daemon at startup
        except TypeError:
            pass  # older spaces without duration param — registration still happens
    except ImportError:
        pass  # `spaces` package not installed — running outside HF Spaces

# Module-level logger (configured by the application, not here).
logger = logging.getLogger(__name__)

# Logical names accepted by the loaders/inference entry points below; each is
# mapped to a concrete local path via settings in _get_model_path().
TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]

# MedGemma 4B IT is a vision-language model (Gemma3ForConditionalGeneration).
# It must be loaded with AutoModelForImageTextToText + AutoProcessor.
# All other models (medgemma-27b-text-it, txgemma-*) are causal LMs.
# On Kaggle T4, medgemma_27b is substituted with medgemma-4b-it (also multimodal),
# so we detect the architecture dynamically from the model config.
_MULTIMODAL_ARCHITECTURES = {"Gemma3ForConditionalGeneration"}


def _get_model_path(model_name: TextModelName) -> str:
    """Resolve a logical model name to its configured local filesystem path.

    Raises:
        RuntimeError: if no path is configured for ``model_name`` in settings.
    """
    settings = get_settings()
    configured: Dict[TextModelName, Optional[str]] = {
        "medgemma_4b": settings.medgemma_4b_model,
        "medgemma_27b": settings.medgemma_27b_model,
        "txgemma_9b": settings.txgemma_9b_model,
        "txgemma_2b": settings.txgemma_2b_model,
    }
    path = configured[model_name]
    # Guard clause: an unset/empty path means the deployment never configured
    # this model, which is a setup error rather than a recoverable condition.
    if path:
        return path
    raise RuntimeError(
        f"No local model path configured for {model_name}. "
        f"Set MEDIC_LOCAL_*_MODEL in your environment or .env."
    )


def _get_load_kwargs() -> Dict[str, Any]:
    """Build the keyword arguments shared by all ``from_pretrained`` calls.

    Always requests ``device_map="auto"``; adds 4-bit quantization when the
    settings ask for it *and* a CUDA device is present.
    """
    import torch

    settings = get_settings()
    cuda_available = torch.cuda.is_available()
    kwargs: Dict[str, Any] = {"device_map": "auto"}
    if cuda_available and settings.quantization == "4bit":
        from transformers import BitsAndBytesConfig
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
    elif not cuda_available:
        logger.warning("No CUDA GPU detected — loading model in float32 on CPU (inference will be slow)")
    return kwargs


@lru_cache(maxsize=8)
def _get_local_multimodal(model_name: TextModelName) -> Callable[..., str]:
    """Load a multimodal model (e.g. MedGemma 4B IT) and return a text generation callable.

    The processor/model pair is loaded once per ``model_name`` (lru_cache) and
    captured by the returned closure; subsequent calls reuse the same weights.
    The closure accepts an optional PIL image for vision-language inference.
    """
    from transformers import AutoModelForImageTextToText, AutoProcessor
    import torch

    model_path = _get_model_path(model_name)
    load_kwargs = _get_load_kwargs()

    logger.info(f"Loading multimodal model: {model_path} with kwargs: {load_kwargs}")
    processor = AutoProcessor.from_pretrained(model_path)
    logger.info(f"Processor loaded for {model_path}")
    model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)
    logger.info(f"Model loaded successfully: {model_path}")

    def _call(
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.2,
        image=None,  # optional PIL.Image.Image for vision-language inference
        **generate_kwargs: Any,
    ) -> str:
        """Generate text for ``prompt`` (optionally conditioned on ``image``)."""
        # Build chat content; prepend image token when an image is provided
        content = []
        if image is not None:
            content.append({"type": "image", "image": image})
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]

        # Tokenize through the chat template so image placeholders and the
        # generation prompt are inserted the way the model expects.
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt",
        ).to(model.device)

        # temperature <= 0 means greedy decoding; temperature is then passed
        # as None so generate() falls back to its default (unused) value.
        do_sample = temperature > 0
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                do_sample=do_sample,
                temperature=temperature if do_sample else None,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )
        # Decode only the newly generated tokens
        generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
        return processor.decode(generated_ids, skip_special_tokens=True).strip()

    return _call


@lru_cache(maxsize=8)
def _get_local_causal_lm(model_name: TextModelName) -> Callable[..., str]:
    """Load a causal LM (e.g. TxGemma, MedGemma 27B text) and return a generation callable.

    The tokenizer/model pair is loaded once per ``model_name`` (lru_cache) and
    captured by the returned closure; later calls reuse the same weights.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    model_path = _get_model_path(model_name)
    load_kwargs = _get_load_kwargs()

    logger.info(f"Loading causal LM: {model_path} with kwargs: {load_kwargs}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    logger.info(f"Tokenizer loaded for {model_path}")
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    logger.info(f"Model loaded successfully: {model_path}")

    def _call(prompt: str, max_new_tokens: int = 512, temperature: float = 0.2, **generate_kwargs: Any) -> str:
        """Generate a completion for ``prompt`` and return only the new text."""
        encoded = tokenizer(prompt, return_tensors="pt")
        encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
        # temperature <= 0 selects greedy decoding; temperature is then passed
        # as None so generate() ignores it.
        sample = temperature > 0
        with torch.no_grad():
            output_ids = model.generate(
                **encoded,
                do_sample=sample,
                temperature=temperature if sample else None,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )
        # Decode only the newly generated tokens, not the input prompt
        prompt_length = encoded["input_ids"].shape[1]
        return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True).strip()

    return _call


@lru_cache(maxsize=8)
def _is_multimodal(model_path: str) -> bool:
    """Return True if the model at ``model_path`` declares a multimodal architecture.

    Any failure while reading the config (missing path, malformed config, …)
    is treated as "not multimodal" so callers fall back to the causal-LM path.
    """
    from transformers import AutoConfig
    try:
        config = AutoConfig.from_pretrained(model_path)
        declared = getattr(config, "architectures", []) or []
        return any(arch in _MULTIMODAL_ARCHITECTURES for arch in declared)
    except Exception:
        return False


@lru_cache(maxsize=32)
def get_text_model(
    model_name: TextModelName = "medgemma_4b",
) -> Callable[..., str]:
    """Return a cached generation callable for the requested model.

    Dispatches to the multimodal loader when the resolved checkpoint declares
    a vision-language architecture, otherwise to the causal-LM loader.
    """
    resolved_path = _get_model_path(model_name)
    loader = _get_local_multimodal if _is_multimodal(resolved_path) else _get_local_causal_lm
    return loader(model_name)



def _inference_core(
    prompt: str,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    **kwargs: Any,
) -> str:
    """Core text inference — no GPU decorator, runs on whatever device is available."""
    generate = get_text_model(model_name=model_name)
    logger.info(f"Model {model_name} ready")
    response = generate(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
    logger.info(f"Inference complete, response length: {len(response)} chars")
    return response


def _inference_with_image_core(
    prompt: str,
    image: Any,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    **kwargs: Any,
) -> str:
    """Core vision inference — no GPU decorator, runs on whatever device is available."""
    model_path = _get_model_path(model_name)
    # A text-only checkpoint may be resolved for this name (e.g. a deployment
    # substitution) — in that case drop the image and run plain text inference.
    if not _is_multimodal(model_path):
        logger.warning(
            f"{model_name} ({model_path}) is not a multimodal model; "
            "falling back to text-only inference."
        )
        return _inference_core(prompt, model_name, max_new_tokens, temperature, **kwargs)
    generate = _get_local_multimodal(model_name)
    response = generate(
        prompt, max_new_tokens=max_new_tokens, temperature=temperature, image=image, **kwargs
    )
    logger.info(f"Vision inference complete, response length: {len(response)} chars")
    return response


def run_inference(
    prompt: str,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    **kwargs: Any,
) -> str:
    """Run inference with the specified model.

    Must be called from within an active @spaces.GPU context (e.g. the
    pipeline wrapper in app.py).  All agents share one GPU session so that
    the lru_cache'd model stays valid across the full pipeline.
    """
    logger.info(f"Running inference with {model_name}, max_tokens={max_new_tokens}, temp={temperature}")
    return _inference_core(
        prompt,
        model_name=model_name,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )


def run_inference_with_image(
    prompt: str,
    image: Any,  # PIL.Image.Image
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    **kwargs: Any,
) -> str:
    """Run vision-language inference passing a PIL image alongside the text prompt.

    Falls back to text-only if the resolved model is not multimodal.
    Must be called from within an active @spaces.GPU context.
    """
    logger.info(f"Running vision inference with {model_name}, max_tokens={max_new_tokens}")
    return _inference_with_image_core(
        prompt,
        image,
        model_name=model_name,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )