# diagnostic-devils-advocate / models / medgemma_client.py
# Uploaded via huggingface_hub by yipengsun (commit 700aa8b, verified)
"""
MedGemma client: unified interface for 4B (multimodal) and 27B (text-only) models.
Loads locally via transformers with optional 4-bit quantization.
"""
from __future__ import annotations
import logging
import os
import threading
from PIL import Image
from config import (
USE_27B, QUANTIZE_4B, HF_TOKEN, DEVICE,
MEDGEMMA_4B_MODEL_ID, MEDGEMMA_27B_MODEL_ID,
MAX_NEW_TOKENS_4B, MAX_NEW_TOKENS_27B, TEMPERATURE, REPETITION_PENALTY,
ENABLE_TORCH_COMPILE, ENABLE_SDPA,
)
from models.utils import strip_thinking_tokens, resize_for_medgemma, apply_prompt_repetition
logger = logging.getLogger(__name__)
# Lazily-loaded singletons, populated on first use by load_4b()/load_27b().
_model_4b = None
_processor_4b = None
_model_27b = None
_tokenizer_27b = None
# One lock per model so concurrent first callers load each model exactly once.
_load_4b_lock = threading.Lock()
_load_27b_lock = threading.Lock()
def _is_local_path(model_id: str) -> bool:
"""Check if model_id is a local directory path."""
return os.path.isdir(model_id)
def _token_arg(model_id: str) -> dict:
    """Build the auth kwarg for ``from_pretrained`` calls.

    Local checkpoints never need a token. For HF Hub downloads, the token
    is forwarded only when explicitly configured; omitting it lets the Hub
    client fall back to cached `huggingface-cli login` credentials (useful
    on local/dev machines).
    """
    if not _is_local_path(model_id) and HF_TOKEN:
        return {"token": HF_TOKEN}
    return {}
def _get_quantization_config():
    """Build the NF4 4-bit BitsAndBytes config used when quantizing the 4B model."""
    import torch
    from transformers import BitsAndBytesConfig

    config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return config
def load_4b():
    """Load MedGemma 4B-IT (multimodal) model and processor.

    Thread-safe lazy singleton: the first caller pays the load cost, every
    later caller receives the cached (model, processor) pair.
    """
    global _model_4b, _processor_4b
    if _model_4b is not None:
        return _model_4b, _processor_4b
    with _load_4b_lock:
        # Double-checked locking: another thread may have finished loading
        # while we were waiting on the lock.
        if _model_4b is not None:
            return _model_4b, _processor_4b
        import torch
        from transformers import AutoModelForImageTextToText, AutoProcessor
        is_local = _is_local_path(MEDGEMMA_4B_MODEL_ID)
        opts = ["4-bit" if QUANTIZE_4B else "bf16"]
        if ENABLE_SDPA:
            opts.append("SDPA")
        if ENABLE_TORCH_COMPILE:
            opts.append("compiled")
        logger.info(
            "Loading MedGemma 4B-IT (%s) from %s...",
            "+".join(opts),
            "local" if is_local else "HF Hub",
        )
        token_kwargs = _token_arg(MEDGEMMA_4B_MODEL_ID)
        # BitsAndBytes quantization requires device_map="auto", not "cuda"
        kwargs = dict(token_kwargs)
        kwargs["device_map"] = "auto" if QUANTIZE_4B else DEVICE
        if QUANTIZE_4B:
            kwargs["quantization_config"] = _get_quantization_config()
        else:
            kwargs["dtype"] = torch.bfloat16
        if ENABLE_SDPA:
            # Scaled-dot-product attention kernel for faster attention.
            kwargs["attn_implementation"] = "sdpa"
        _processor_4b = AutoProcessor.from_pretrained(MEDGEMMA_4B_MODEL_ID, **token_kwargs)
        _model_4b = AutoModelForImageTextToText.from_pretrained(MEDGEMMA_4B_MODEL_ID, **kwargs)
        _model_4b.eval()
        if ENABLE_TORCH_COMPILE:
            # torch.compile JIT: the first inference triggers compilation.
            logger.info("Compiling model with torch.compile (first inference will be slow)...")
            _model_4b = torch.compile(_model_4b, mode="reduce-overhead")
        logger.info("MedGemma 4B loaded.")
        return _model_4b, _processor_4b
def load_27b():
    """Load MedGemma 27B Text-IT model and tokenizer (A100 only).

    Thread-safe lazy singleton; returns the cached (model, tokenizer) pair
    on subsequent calls. Always loads in bf16 with device_map="auto" (the
    27B path has no quantization option, unlike load_4b).
    """
    global _model_27b, _tokenizer_27b
    if _model_27b is not None:
        return _model_27b, _tokenizer_27b
    with _load_27b_lock:
        # Double-checked locking: re-test after acquiring the lock.
        if _model_27b is not None:
            return _model_27b, _tokenizer_27b
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        is_local = _is_local_path(MEDGEMMA_27B_MODEL_ID)
        opts = ["bf16"]
        if ENABLE_SDPA:
            opts.append("SDPA")
        if ENABLE_TORCH_COMPILE:
            opts.append("compiled")
        logger.info(
            "Loading MedGemma 27B Text-IT (%s) from %s...",
            "+".join(opts),
            "local" if is_local else "HF Hub",
        )
        kwargs = {
            **_token_arg(MEDGEMMA_27B_MODEL_ID),
            # Use `dtype` (not the deprecated `torch_dtype` alias) for
            # consistency with load_4b, which already passes `dtype`.
            "dtype": torch.bfloat16,
            "device_map": "auto",
        }
        if ENABLE_SDPA:
            kwargs["attn_implementation"] = "sdpa"
        _tokenizer_27b = AutoTokenizer.from_pretrained(MEDGEMMA_27B_MODEL_ID, **_token_arg(MEDGEMMA_27B_MODEL_ID))
        _model_27b = AutoModelForCausalLM.from_pretrained(MEDGEMMA_27B_MODEL_ID, **kwargs)
        _model_27b.eval()
        if ENABLE_TORCH_COMPILE:
            # torch.compile JIT: the first inference triggers compilation.
            logger.info("Compiling model with torch.compile (first inference will be slow)...")
            _model_27b = torch.compile(_model_27b, mode="reduce-overhead")
        logger.info("MedGemma 27B loaded.")
        return _model_27b, _tokenizer_27b
def generate_with_image(prompt: str, image: Image.Image, system_prompt: str = "") -> str:
    """Generate text from an image plus text prompt using MedGemma 4B."""
    import torch
    model, processor = load_4b()
    image = resize_for_medgemma(image)
    prompt = apply_prompt_repetition(prompt)
    messages = (
        [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
        if system_prompt
        else []
    )
    messages.append({
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    })
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    ).to(model.device)
    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS_4B,
        "do_sample": TEMPERATURE > 0,
        "repetition_penalty": REPETITION_PENALTY,
    }
    if TEMPERATURE > 0:
        gen_kwargs["temperature"] = TEMPERATURE
    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)
    # Keep only the freshly generated tokens (drop the prompt echo).
    prompt_len = inputs["input_ids"].shape[1]
    text = processor.tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    return strip_thinking_tokens(text)
def generate_text(prompt: str, system_prompt: str = "") -> str:
    """Generate text from a text-only prompt; routes to 27B when enabled, else 4B."""
    backend = _generate_text_27b if USE_27B else _generate_text_4b
    return backend(prompt, system_prompt)
def _generate_text_4b(prompt: str, system_prompt: str = "") -> str:
    """Text-only generation via the 4B multimodal model (no image turn)."""
    import torch
    model, processor = load_4b()
    prompt = apply_prompt_repetition(prompt)
    messages = (
        [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
        if system_prompt
        else []
    )
    messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    ).to(model.device)
    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS_4B,
        "do_sample": TEMPERATURE > 0,
        "repetition_penalty": REPETITION_PENALTY,
    }
    if TEMPERATURE > 0:
        gen_kwargs["temperature"] = TEMPERATURE
    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)
    # Keep only the freshly generated tokens (drop the prompt echo).
    prompt_len = inputs["input_ids"].shape[1]
    text = processor.tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    return strip_thinking_tokens(text)
def _generate_text_27b(prompt: str, system_prompt: str = "") -> str:
    """Text-only generation via the 27B model (thinking mode).

    Note: the 27B chat format uses plain-string `content`, unlike the 4B
    processor's list-of-parts format.
    """
    import torch
    model, tokenizer = load_27b()
    prompt = apply_prompt_repetition(prompt)
    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.append({"role": "user", "content": prompt})
    chat_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)
    sampling = {"temperature": TEMPERATURE} if TEMPERATURE > 0 else {}
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS_27B,
            do_sample=TEMPERATURE > 0,
            repetition_penalty=REPETITION_PENALTY,
            **sampling,
        )
    # Keep only the freshly generated tokens (drop the prompt echo).
    prompt_len = inputs["input_ids"].shape[1]
    text = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    return strip_thinking_tokens(text)