"""
Model handler for MedGemma 1.5 inference.
"""
import os
import torch
from PIL import Image
from typing import List, Optional
from dotenv import load_dotenv
from transformers import AutoProcessor, AutoModelForImageTextToText
# Load environment variables from .env file
# (e.g. HF_TOKEN, read later by MedGemmaHandler.load_model for Hub auth).
load_dotenv()
def check_gpu_availability():
    """Print a CUDA diagnostic banner and report whether a GPU is usable.

    Returns:
        bool: True when torch can see at least one CUDA device, else False.
    """
    banner = "=" * 60
    print(banner)
    print("GPU Availability Check")
    print(banner)

    has_cuda = torch.cuda.is_available()
    print(f"CUDA available: {has_cuda}")

    if not has_cuda:
        # No GPU: tell the user how to get CUDA-enabled torch.
        print("CUDA is not available. Model will use CPU (slow).")
        print("\nTo use GPU, ensure you have:")
        print("1. NVIDIA GPU with CUDA support")
        print("2. CUDA toolkit installed")
        print("3. PyTorch with CUDA support: pip install torch --index-url https://download.pytorch.org/whl/cu118")
    else:
        gpu_total = torch.cuda.device_count()
        print(f"Number of GPUs: {gpu_total}")
        for idx in range(gpu_total):
            print(f" GPU {idx}: {torch.cuda.get_device_name(idx)}")
        print(f"Current GPU: {torch.cuda.current_device()}")
        # MIG (Multi-Instance GPU) partitions need the float32 fallback
        # used elsewhere in this module.
        if "MIG" in torch.cuda.get_device_name(0):
            print("Note: Running on MIG partition - using float32 for compatibility")

    print(banner)
    return has_cuda
class MedGemmaHandler:
    """Handler for MedGemma 1.5 model inference.

    Typical usage:
        handler = MedGemmaHandler()
        handler.load_model()
        report = handler.generate_report(images, prompt)
    """

    def __init__(self, model_id: str = "google/medgemma-1.5-4b-it", device: Optional[str] = None):
        """Configure the handler; no weights are downloaded here.

        Args:
            model_id: Hugging Face model id. Overridden by a local copy in
                ./models/medgemma-1.5-4b-it when one is present next to this file.
            device: Explicit torch device string ("cuda"/"cpu"). Auto-detected
                in load_model() when None.
        """
        self.model_id = model_id
        self.device = device
        self.processor = None
        self.model = None
        self.use_float32 = False  # True on CPU or MIG partitions (CUBLAS workaround)

        # Prefer a local checkout of the weights when available (useful for
        # local development without hitting the Hub). isfile() on config.json
        # already implies the directory exists, so no separate exists() check.
        local_model_path = os.path.join(os.path.dirname(__file__), "models", "medgemma-1.5-4b-it")
        if os.path.isfile(os.path.join(local_model_path, "config.json")):
            self.model_id = local_model_path
            print(f"Using local model from: {local_model_path}")
        else:
            print(f"Using model from Hugging Face Hub: {self.model_id}")

    def load_model(self):
        """Load the MedGemma 1.5 model and processor.

        Resolves the target device, then loads with a dtype/attention setup
        that works on full GPUs (bfloat16), MIG partitions (float32 + eager
        attention, avoiding SDPA/CUBLAS failures) and CPU (float32).

        Raises:
            Whatever transformers raises on download/auth failure.
        """
        print(f"Loading MedGemma model: {self.model_id}")
        cuda_available = check_gpu_availability()

        # Resolve device when the caller did not pin one.
        if self.device is None:
            if cuda_available:
                self.device = "cuda"
                print(f"Using GPU: {torch.cuda.get_device_name(0)}")
            else:
                self.device = "cpu"
                print("WARNING: Using CPU - this will be very slow!")
        else:
            print(f"Using device: {self.device}")

        # FIX: run the MIG check for ANY cuda device, including one passed
        # explicitly to __init__. Previously it only ran on the auto-detect
        # path, so explicit device="cuda" on a MIG slice loaded bfloat16+SDPA
        # and hit CUBLAS errors.
        if self.device.startswith("cuda") and cuda_available:
            if "MIG" in torch.cuda.get_device_name(0):
                self.use_float32 = True
                print("MIG detected: Using float32 for CUBLAS compatibility")
        else:
            self.use_float32 = True  # CPU path always runs float32

        # Get HF token from environment (.env loaded at import time).
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            print("Using Hugging Face token from .env file")
        else:
            print("Warning: No HF_TOKEN found in .env file")

        self.processor = AutoProcessor.from_pretrained(self.model_id, token=hf_token)

        # One from_pretrained call; only dtype/device/attention differ by path
        # (replaces three near-duplicate calls).
        load_kwargs = {"token": hf_token}
        if self.device.startswith("cuda") and torch.cuda.is_available():
            if self.use_float32:
                print("Loading model on GPU with float32 + eager attention (MIG compatibility)...")
                # eager attention disables SDPA kernels that fail on MIG
                load_kwargs.update(torch_dtype=torch.float32, device_map="cuda", attn_implementation="eager")
            else:
                print("Loading model on GPU with bfloat16...")
                load_kwargs.update(torch_dtype=torch.bfloat16, device_map="cuda")
        else:
            print("Loading model on CPU (this may take a while)...")
            load_kwargs.update(torch_dtype=torch.float32, device_map="cpu")
        self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **load_kwargs)

        first_param = next(self.model.parameters())
        print(f"Model loaded on device: {first_param.device}")
        print(f"Model dtype: {first_param.dtype}")
        print("Model loaded successfully!")

    def generate_report(
        self,
        images: List[Image.Image],
        prompt: str,
        max_new_tokens: int = 350,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        do_sample: bool = True,
    ) -> str:
        """Generate a radiology report from medical images.

        Args:
            images: PIL images to condition on (all sent in one user turn).
            prompt: Text instruction appended after the images.
            max_new_tokens: Cap on generated tokens.
            temperature / top_p / top_k: Sampling parameters; used only when
                do_sample is True and temperature > 0, otherwise generation
                is greedy.
            do_sample: Enable sampling.

        Returns:
            The decoded model output, prompt tokens stripped.

        Raises:
            RuntimeError: If load_model() has not been called.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        print(f"Processing {len(images)} images...")

        # Single-turn chat message: all images first, the text prompt last.
        content = [{"type": "image", "image": img} for img in images]
        content.append({"type": "text", "text": prompt})
        messages = [
            {
                "role": "user",
                "content": content
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        )
        # Move to device - let the model handle dtype conversion
        inputs = inputs.to(self.model.device)
        input_len = inputs["input_ids"].shape[-1]
        print(f"Input sequence length: {input_len}")

        # Greedy vs sampled decoding share one generate() call.
        gen_kwargs = {"max_new_tokens": max_new_tokens}
        if do_sample and temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            gen_kwargs.update(do_sample=False)
        with torch.inference_mode():
            generation = self.model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tail, skipping the prompt tokens.
        report = self.processor.decode(generation[0][input_len:], skip_special_tokens=True)

        # FIX: clear the CUDA cache for any cuda device string ("cuda",
        # "cuda:0", ...), not only the exact literal "cuda".
        if self.device and str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()
            print("GPU cache cleared.")
        return report
|