""" MedGemma client: unified interface for 4B (multimodal) and 27B (text-only) models. Loads locally via transformers with optional 4-bit quantization. """ from __future__ import annotations import logging import os import threading from PIL import Image from config import ( USE_27B, QUANTIZE_4B, HF_TOKEN, DEVICE, MEDGEMMA_4B_MODEL_ID, MEDGEMMA_27B_MODEL_ID, MAX_NEW_TOKENS_4B, MAX_NEW_TOKENS_27B, TEMPERATURE, REPETITION_PENALTY, ENABLE_TORCH_COMPILE, ENABLE_SDPA, ) from models.utils import strip_thinking_tokens, resize_for_medgemma, apply_prompt_repetition logger = logging.getLogger(__name__) _model_4b = None _processor_4b = None _model_27b = None _tokenizer_27b = None _load_4b_lock = threading.Lock() _load_27b_lock = threading.Lock() def _is_local_path(model_id: str) -> bool: """Check if model_id is a local directory path.""" return os.path.isdir(model_id) def _token_arg(model_id: str) -> dict: """Return token kwarg only when loading from HF Hub (not local path).""" if _is_local_path(model_id): return {} # Only pass `token` when explicitly provided; omitting it lets HF Hub fall back # to `huggingface-cli login` cached credentials (useful on local/dev machines). if HF_TOKEN: return {"token": HF_TOKEN} return {} def _get_quantization_config(): """Return BitsAndBytesConfig for 4-bit quantization.""" import torch from transformers import BitsAndBytesConfig return BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", ) def load_4b(): """Load MedGemma 4B-IT (multimodal) model and processor.""" global _model_4b, _processor_4b if _model_4b is not None: return _model_4b, _processor_4b with _load_4b_lock: if _model_4b is not None: return _model_4b, _processor_4b import torch from transformers import AutoModelForImageTextToText, AutoProcessor is_local = _is_local_path(MEDGEMMA_4B_MODEL_ID) opts = [] if QUANTIZE_4B: opts.append("4-bit") else: opts.append("bf16") if ENABLE_SDPA: opts.append("SDPA") if ENABLE_TORCH_COMPILE: opts.append("compiled") logger.info( "Loading MedGemma 4B-IT (%s) from %s...", "+".join(opts), "local" if is_local else "HF Hub", ) # BitsAndBytes quantization requires device_map="auto", not "cuda" device_map = "auto" if QUANTIZE_4B else DEVICE kwargs = {**_token_arg(MEDGEMMA_4B_MODEL_ID), "device_map": device_map} if QUANTIZE_4B: kwargs["quantization_config"] = _get_quantization_config() else: kwargs["dtype"] = torch.bfloat16 # SDPA: 优化注意力计算 if ENABLE_SDPA: kwargs["attn_implementation"] = "sdpa" _processor_4b = AutoProcessor.from_pretrained(MEDGEMMA_4B_MODEL_ID, **_token_arg(MEDGEMMA_4B_MODEL_ID)) _model_4b = AutoModelForImageTextToText.from_pretrained(MEDGEMMA_4B_MODEL_ID, **kwargs) _model_4b.eval() # torch.compile: JIT 编译加速(首次推理会编译,耐心等待) if ENABLE_TORCH_COMPILE: logger.info("Compiling model with torch.compile (first inference will be slow)...") _model_4b = torch.compile(_model_4b, mode="reduce-overhead") logger.info("MedGemma 4B loaded.") return _model_4b, _processor_4b def load_27b(): """Load MedGemma 27B Text-IT model and tokenizer (A100 only).""" global _model_27b, _tokenizer_27b if _model_27b is not None: return _model_27b, _tokenizer_27b with _load_27b_lock: if _model_27b is not None: return _model_27b, _tokenizer_27b import torch from transformers import AutoModelForCausalLM, AutoTokenizer is_local = _is_local_path(MEDGEMMA_27B_MODEL_ID) opts = ["bf16"] if ENABLE_SDPA: opts.append("SDPA") if ENABLE_TORCH_COMPILE: opts.append("compiled") logger.info( "Loading MedGemma 27B Text-IT (%s) from %s...", "+".join(opts), "local" if is_local else "HF Hub", ) kwargs = { **_token_arg(MEDGEMMA_27B_MODEL_ID), "torch_dtype": torch.bfloat16, "device_map": "auto", } if ENABLE_SDPA: kwargs["attn_implementation"] = "sdpa" _tokenizer_27b = AutoTokenizer.from_pretrained(MEDGEMMA_27B_MODEL_ID, **_token_arg(MEDGEMMA_27B_MODEL_ID)) _model_27b = AutoModelForCausalLM.from_pretrained(MEDGEMMA_27B_MODEL_ID, **kwargs) _model_27b.eval() if ENABLE_TORCH_COMPILE: logger.info("Compiling model with torch.compile (first inference will be slow)...") _model_27b = torch.compile(_model_27b, mode="reduce-overhead") logger.info("MedGemma 27B loaded.") return _model_27b, _tokenizer_27b def generate_with_image(prompt: str, image: Image.Image, system_prompt: str = "") -> str: """Generate text from image + text prompt using MedGemma 4B.""" model, processor = load_4b() image = resize_for_medgemma(image) prompt = apply_prompt_repetition(prompt) messages = [] if system_prompt: messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]}) messages.append({ "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": prompt}, ], }) inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(model.device) import torch with torch.inference_mode(): output_ids = model.generate( **inputs, max_new_tokens=MAX_NEW_TOKENS_4B, do_sample=TEMPERATURE > 0, repetition_penalty=REPETITION_PENALTY, **({"temperature": TEMPERATURE} if TEMPERATURE > 0 else {}), ) # Decode only the new tokens new_tokens = output_ids[0, inputs["input_ids"].shape[1]:] text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True) return strip_thinking_tokens(text) def generate_text(prompt: str, system_prompt: str = "") -> str: """Generate text from text-only prompt. Uses 27B if available, else 4B.""" if USE_27B: return _generate_text_27b(prompt, system_prompt) return _generate_text_4b(prompt, system_prompt) def _generate_text_4b(prompt: str, system_prompt: str = "") -> str: """Text-only generation with 4B model.""" model, processor = load_4b() prompt = apply_prompt_repetition(prompt) messages = [] if system_prompt: messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]}) messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]}) inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(model.device) import torch with torch.inference_mode(): output_ids = model.generate( **inputs, max_new_tokens=MAX_NEW_TOKENS_4B, do_sample=TEMPERATURE > 0, repetition_penalty=REPETITION_PENALTY, **({"temperature": TEMPERATURE} if TEMPERATURE > 0 else {}), ) new_tokens = output_ids[0, inputs["input_ids"].shape[1]:] text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True) return strip_thinking_tokens(text) def _generate_text_27b(prompt: str, system_prompt: str = "") -> str: """Text-only generation with 27B model (thinking mode).""" model, tokenizer = load_27b() prompt = apply_prompt_repetition(prompt) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) inputs = tokenizer(input_text, return_tensors="pt").to(model.device) import torch with torch.inference_mode(): output_ids = model.generate( **inputs, max_new_tokens=MAX_NEW_TOKENS_27B, do_sample=TEMPERATURE > 0, repetition_penalty=REPETITION_PENALTY, **({"temperature": TEMPERATURE} if TEMPERATURE > 0 else {}), ) new_tokens = output_ids[0, inputs["input_ids"].shape[1]:] text = tokenizer.decode(new_tokens, skip_special_tokens=True) return strip_thinking_tokens(text)