File size: 7,136 Bytes
9e0c9cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Model handler for MedGemma 1.5 inference.
"""
import os
import torch
from PIL import Image
from typing import List, Optional
from dotenv import load_dotenv
from transformers import AutoProcessor, AutoModelForImageTextToText

# Load environment variables from .env file
load_dotenv()


def check_gpu_availability():
    """Print a diagnostic summary of CUDA availability and return it.

    Returns:
        bool: ``True`` when ``torch.cuda.is_available()`` reports CUDA
        support, ``False`` otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("GPU Availability Check")
    print(banner)

    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")

    if not cuda_available:
        # No CUDA runtime visible to torch — explain how to enable it.
        print("CUDA is not available. Model will use CPU (slow).")
        print("\nTo use GPU, ensure you have:")
        print("1. NVIDIA GPU with CUDA support")
        print("2. CUDA toolkit installed")
        print("3. PyTorch with CUDA support: pip install torch --index-url https://download.pytorch.org/whl/cu118")
    else:
        num_devices = torch.cuda.device_count()
        print(f"Number of GPUs: {num_devices}")
        for idx in range(num_devices):
            print(f"  GPU {idx}: {torch.cuda.get_device_name(idx)}")
        print(f"Current GPU: {torch.cuda.current_device()}")

        # MIG (Multi-Instance GPU) partitions need float32 downstream.
        if "MIG" in torch.cuda.get_device_name(0):
            print("Note: Running on MIG partition - using float32 for compatibility")

    print(banner)

    return cuda_available


class MedGemmaHandler:
    """Handler for MedGemma 1.5 model inference.

    Wraps loading of the Hugging Face processor/model pair (from the Hub or
    a local ``models/medgemma-1.5-4b-it`` checkout) and text generation from
    a list of medical images plus a text prompt.
    """

    def __init__(self, model_id: str = "google/medgemma-1.5-4b-it", device: Optional[str] = None):
        """Initialize the handler without loading any weights.

        Args:
            model_id: Hugging Face model id or path. Overridden by a local
                ``models/medgemma-1.5-4b-it`` checkout when one exists next
                to this file.
            device: Explicit device string (e.g. ``"cuda"``, ``"cuda:0"``,
                ``"cpu"``). When ``None``, the device is auto-detected in
                :meth:`load_model`.
        """
        self.model_id = model_id
        self.device = device
        self.processor = None  # populated by load_model()
        self.model = None  # populated by load_model()
        self.use_float32 = False  # Flag for MIG/CPU compatibility

        # Prefer a local model checkout when present (useful for local
        # development without Hub access). config.json is the marker of a
        # complete checkout.
        local_model_path = os.path.join(os.path.dirname(__file__), "models", "medgemma-1.5-4b-it")
        if os.path.exists(local_model_path) and os.path.isfile(os.path.join(local_model_path, "config.json")):
            self.model_id = local_model_path
            print(f"Using local model from: {local_model_path}")
        else:
            print(f"Using model from Hugging Face Hub: {self.model_id}")

    def load_model(self):
        """Load the MedGemma 1.5 model and processor.

        Auto-selects CUDA when available (falling back to CPU + float32),
        and forces float32 with eager attention on MIG partitions to avoid
        CUBLAS issues. Reads ``HF_TOKEN`` from the environment for gated
        model access.
        """
        print(f"Loading MedGemma model: {self.model_id}")

        # Check GPU availability (also prints diagnostics).
        cuda_available = check_gpu_availability()

        # Determine device when the caller did not pin one.
        if self.device is None:
            if cuda_available:
                self.device = "cuda"
                gpu_name = torch.cuda.get_device_name(0)
                print(f"Using GPU: {gpu_name}")
                # Check for MIG partition - use float32 for compatibility
                if "MIG" in gpu_name:
                    self.use_float32 = True
                    print("MIG detected: Using float32 for CUBLAS compatibility")
            else:
                self.device = "cpu"
                self.use_float32 = True
                print("WARNING: Using CPU - this will be very slow!")
        else:
            # NOTE(review): an explicit device skips MIG auto-detection on
            # purpose — the caller's choice of dtype behavior is respected.
            print(f"Using device: {self.device}")

        # Get HF token from environment (loaded from .env at import time).
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            print("Using Hugging Face token from .env file")
        else:
            print("Warning: No HF_TOKEN found in .env file")

        self.processor = AutoProcessor.from_pretrained(self.model_id, token=hf_token)

        # Load model with proper device configuration.
        # FIX: an explicitly passed device such as "cuda:0" previously failed
        # the `== "cuda"` comparison and silently loaded the model on CPU;
        # startswith("cuda") accepts any CUDA device string, and
        # device_map=self.device honors a specific GPU index.
        if self.device.startswith("cuda") and torch.cuda.is_available():
            if self.use_float32:
                print("Loading model on GPU with float32 + eager attention (MIG compatibility)...")
                self.model = AutoModelForImageTextToText.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.float32,
                    device_map=self.device,
                    token=hf_token,
                    attn_implementation="eager",  # Disable SDPA for MIG compatibility
                )
            else:
                print("Loading model on GPU with bfloat16...")
                self.model = AutoModelForImageTextToText.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.bfloat16,
                    device_map=self.device,
                    token=hf_token,
                )
        else:
            print("Loading model on CPU (this may take a while)...")
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_id,
                torch_dtype=torch.float32,
                device_map="cpu",
                token=hf_token,
            )

        print(f"Model loaded on device: {next(self.model.parameters()).device}")
        print(f"Model dtype: {next(self.model.parameters()).dtype}")
        print("Model loaded successfully!")

    def generate_report(
        self,
        images: List[Image.Image],
        prompt: str,
        max_new_tokens: int = 350,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        do_sample: bool = True,
    ) -> str:
        """Generate a radiology report from medical images.

        Args:
            images: PIL images to include in the prompt, in order.
            prompt: Text instruction appended after the images.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature; sampling is only used when
                ``do_sample`` is True and ``temperature`` > 0, otherwise
                greedy decoding is performed.
            top_p: Nucleus-sampling probability mass (sampling path only).
            top_k: Top-k cutoff (sampling path only).
            do_sample: Enable sampling-based decoding.

        Returns:
            The decoded report text (special tokens stripped).

        Raises:
            RuntimeError: If :meth:`load_model` has not been called.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        print(f"Processing {len(images)} images...")

        # Build a single user turn: all images first, then the text prompt.
        content = [{"type": "image", "image": img} for img in images]
        content.append({"type": "text", "text": prompt})

        messages = [
            {
                "role": "user",
                "content": content
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        )

        # Move to device - let the model handle dtype conversion
        inputs = inputs.to(self.model.device)

        input_len = inputs["input_ids"].shape[-1]
        print(f"Input sequence length: {input_len}")

        with torch.inference_mode():
            if do_sample and temperature > 0:
                generation = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                )
            else:
                # Greedy decoding when sampling is disabled or degenerate.
                generation = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                )
            # Strip the prompt tokens, keeping only the newly generated tail.
            generation = generation[0][input_len:]

        report = self.processor.decode(generation, skip_special_tokens=True)

        # Clear GPU cache after inference.
        # FIX: match any CUDA device string ("cuda", "cuda:0", ...), not
        # only the bare "cuda" literal.
        if self.device is not None and self.device.startswith("cuda"):
            torch.cuda.empty_cache()
            print("GPU cache cleared.")

        return report