# NOTE: removed extraction artifact (file-size line and git-blame hash/line-number columns)
# that was not part of the original Python source.
import logging
import os
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional
from .config import get_settings
# ── ZeroGPU registration (HF Spaces only) ────────────────────────────────────
# Calling spaces.GPU() at import time establishes the connection with the
# ZeroGPU daemon. The actual @spaces.GPU decorator for the pipeline lives in
# app.py and wraps the *entire* multi-agent run so all agents share one GPU
# session (avoids lru_cache + freed-CUDA-memory hangs between agent calls).
# SPACE_ID is set only when running inside a Hugging Face Space.
if os.environ.get("SPACE_ID"):
    try:
        import spaces as _spaces
        try:
            _spaces.GPU(duration=600)  # register with ZeroGPU daemon at startup
        except TypeError:
            pass  # older spaces without duration param — registration still happens
    except ImportError:
        # `spaces` package unavailable — nothing to register.
        pass
# Module-level logger; handlers/levels are configured by the application entry point.
logger = logging.getLogger(__name__)

# Closed set of model identifiers this module knows how to load locally.
TextModelName = Literal["medgemma_4b", "medgemma_27b", "txgemma_9b", "txgemma_2b"]

# MedGemma 4B IT is a vision-language model (Gemma3ForConditionalGeneration).
# It must be loaded with AutoModelForImageTextToText + AutoProcessor.
# All other models (medgemma-27b-text-it, txgemma-*) are causal LMs.
# On Kaggle T4, medgemma_27b is substituted with medgemma-4b-it (also multimodal),
# so we detect the architecture dynamically from the model config.
_MULTIMODAL_ARCHITECTURES = {"Gemma3ForConditionalGeneration"}
def _get_model_path(model_name: TextModelName) -> str:
    """Resolve *model_name* to the locally configured model path.

    Raises:
        RuntimeError: if no path is configured for the requested model.
    """
    settings = get_settings()
    configured: Dict[TextModelName, Optional[str]] = {
        "medgemma_4b": settings.medgemma_4b_model,
        "medgemma_27b": settings.medgemma_27b_model,
        "txgemma_9b": settings.txgemma_9b_model,
        "txgemma_2b": settings.txgemma_2b_model,
    }
    path = configured[model_name]
    if path:
        return path
    raise RuntimeError(
        f"No local model path configured for {model_name}. "
        f"Set MEDIC_LOCAL_*_MODEL in your environment or .env."
    )
def _get_load_kwargs() -> Dict[str, Any]:
    """Build the shared ``from_pretrained`` kwargs: device map plus optional 4-bit quantization."""
    import torch

    settings = get_settings()
    kwargs: Dict[str, Any] = {"device_map": "auto"}
    if not torch.cuda.is_available():
        # CPU fallback: no quantization possible, model stays in float32.
        logger.warning("No CUDA GPU detected — loading model in float32 on CPU (inference will be slow)")
    elif settings.quantization == "4bit":
        from transformers import BitsAndBytesConfig

        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
    return kwargs
@lru_cache(maxsize=8)
def _get_local_multimodal(model_name: TextModelName):
    """Load a multimodal model (e.g. MedGemma 4B IT) and return a text generation callable."""
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    model_path = _get_model_path(model_name)
    load_kwargs = _get_load_kwargs()
    logger.info(f"Loading multimodal model: {model_path} with kwargs: {load_kwargs}")
    processor = AutoProcessor.from_pretrained(model_path)
    logger.info(f"Processor loaded for {model_path}")
    model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)
    logger.info(f"Model loaded successfully: {model_path}")

    def _generate(
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.2,
        image=None,  # optional PIL.Image.Image for vision-language inference
        **generate_kwargs: Any,
    ) -> str:
        """Generate a completion for *prompt*, optionally conditioned on *image*."""
        # Assemble the chat turn; when an image is supplied it precedes the text.
        parts = [{"type": "text", "text": prompt}]
        if image is not None:
            parts.insert(0, {"type": "image", "image": image})
        inputs = processor.apply_chat_template(
            [{"role": "user", "content": parts}],
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        sampling = temperature > 0
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                do_sample=sampling,
                temperature=temperature if sampling else None,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )
        # Decode only the freshly generated tokens, not the echoed prompt.
        new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
        return processor.decode(new_tokens, skip_special_tokens=True).strip()

    return _generate
@lru_cache(maxsize=8)
def _get_local_causal_lm(model_name: TextModelName):
    """Load a causal LM (e.g. TxGemma, MedGemma 27B text) and return a generation callable."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = _get_model_path(model_name)
    load_kwargs = _get_load_kwargs()
    logger.info(f"Loading causal LM: {model_path} with kwargs: {load_kwargs}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    logger.info(f"Tokenizer loaded for {model_path}")
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    logger.info(f"Model loaded successfully: {model_path}")

    def _generate(prompt: str, max_new_tokens: int = 512, temperature: float = 0.2, **generate_kwargs: Any) -> str:
        """Generate a completion for *prompt* and return only the newly produced text."""
        encoded = tokenizer(prompt, return_tensors="pt")
        inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
        sampling = temperature > 0
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                do_sample=sampling,
                temperature=temperature if sampling else None,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )
        # Slice off the prompt tokens so only the completion is decoded.
        new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return _generate
@lru_cache(maxsize=8)
def _is_multimodal(model_path: str) -> bool:
    """Check if a model uses a multimodal architecture by inspecting its config."""
    from transformers import AutoConfig

    try:
        config = AutoConfig.from_pretrained(model_path)
    except Exception:
        # Best-effort: an unreadable/missing config is treated as text-only.
        return False
    declared = getattr(config, "architectures", []) or []
    return any(arch in _MULTIMODAL_ARCHITECTURES for arch in declared)
@lru_cache(maxsize=32)
def get_text_model(
    model_name: TextModelName = "medgemma_4b",
) -> Callable[..., str]:
    """Return a cached callable for the requested model."""
    model_path = _get_model_path(model_name)
    # Architecture decides the loader: vision-language vs plain causal LM.
    loader = _get_local_multimodal if _is_multimodal(model_path) else _get_local_causal_lm
    return loader(model_name)
def _inference_core(
    prompt: str,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    **kwargs: Any,
) -> str:
    """Core text inference — no GPU decorator, runs on whatever device is available."""
    generate = get_text_model(model_name=model_name)
    logger.info(f"Model {model_name} ready")
    response = generate(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
    logger.info(f"Inference complete, response length: {len(response)} chars")
    return response
def _inference_with_image_core(
    prompt: str,
    image: Any,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    **kwargs: Any,
) -> str:
    """Core vision inference — no GPU decorator, runs on whatever device is available."""
    model_path = _get_model_path(model_name)
    if _is_multimodal(model_path):
        generate = _get_local_multimodal(model_name)
        response = generate(
            prompt, max_new_tokens=max_new_tokens, temperature=temperature, image=image, **kwargs
        )
        logger.info(f"Vision inference complete, response length: {len(response)} chars")
        return response
    # Resolved model cannot accept images — degrade gracefully to text-only.
    logger.warning(
        f"{model_name} ({model_path}) is not a multimodal model; "
        "falling back to text-only inference."
    )
    return _inference_core(prompt, model_name, max_new_tokens, temperature, **kwargs)
def run_inference(
    prompt: str,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    **kwargs: Any,
) -> str:
    """Run text inference with the specified model.

    Must be called from within an active @spaces.GPU context (e.g. the
    pipeline wrapper in app.py). All agents share one GPU session so that
    the lru_cache'd model stays valid across the full pipeline.
    """
    logger.info(f"Running inference with {model_name}, max_tokens={max_new_tokens}, temp={temperature}")
    return _inference_core(
        prompt,
        model_name=model_name,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )
def run_inference_with_image(
    prompt: str,
    image: Any,  # PIL.Image.Image
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    **kwargs: Any,
) -> str:
    """Run vision-language inference passing a PIL image alongside the text prompt.

    Falls back to text-only if the resolved model is not multimodal.
    Must be called from within an active @spaces.GPU context.
    """
    logger.info(f"Running vision inference with {model_name}, max_tokens={max_new_tokens}")
    return _inference_with_image_core(
        prompt,
        image,
        model_name=model_name,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )
|