import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from transformers import TorchAoConfig, Qwen2_5_VLForConditionalGeneration, Gemma3ForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq, AutoModel
from qwen_vl_utils import process_vision_info
import gc
# from transformers.image_utils import load_image

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

class VLMManager:
    """
    A manager class for Vision-Language Models that handles model loading,
    caching, and dynamic switching between different models.
    """
    
    def __init__(self, default_model: str = "Gemma3-4B"):
        """
        Initialize the VLM Manager with a default model.
        
        Args:
            default_model (str): The default model to load initially.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.current_model_name = None
        self.processor = None
        self.tokenizer = None  # Initialize tokenizer attribute
        self.model = None

        self.system_message = """
        You are an expert cultural-aware image-analysis assistant. For every image:
        1. Output exactly 40 words in total.
        2. Use a single paragraph (no lists or bullet points).
        3. Describe Who (appearance/emotion), What (action), and Where (setting).
        4. Do NOT include opinions or speculations.
        5. If you go over 40 words, shorten or remove non-essential details.
        """

        self.user_prompt = """
        Given this image, please provide an image description of around 40 words with extensive and detailed visual information. 

        Descriptions must be objective: focus on how you would describe the image to someone who can't see it, without your own opinions/speculations. 

        The text needs to include the main concept and describe the content of the image in detail by including:
        - Who?: The visual appearance and observable emotions (e.g., "is smiling") of persons and animals.
        - What?: The actions performed in the image.
        - Where?: The setting of the image, including the size, color, and relationships between objects.
        """
        
        # Load the default model
        self.load_model(default_model)
    
    def load_model(self, model_name: str):
        """
        Load a VLM model. If the requested model is already loaded, reuse the cached version.
        
        Args:
            model_name (str): The name of the model to load.
        
        Returns:
            tuple: (tokenizer, model) for InternVL3_5-8B, otherwise (processor, model).
        """
        # If the requested model is already loaded, no need to reload
        if self.current_model_name == model_name and self.model is not None:
            print(f"Model {model_name} is already loaded, using cached version.")
            if self.current_model_name == "InternVL3_5-8B":
                return self.tokenizer, self.model
            else:
                return self.processor, self.model
        
        print(f"Loading model: {model_name}")
        
        # Clear current model from memory if exists
        if self.model is not None:
            del self.model
            self.model = None
            if self.current_model_name == "InternVL3_5-8B":
                if hasattr(self, 'tokenizer') and self.tokenizer is not None:
                    del self.tokenizer
                    self.tokenizer = None
            else:
                if hasattr(self, 'processor') and self.processor is not None:
                    del self.processor
                    self.processor = None
            # Force garbage collection and clear CUDA cache
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()  # Wait for all operations to complete
        
        # Load the new model
        if model_name == "SmolVLM-500M":
            self.processor, self.model = self._load_smolvlm_model("HuggingFaceTB/SmolVLM-500M-Instruct")
        elif model_name == "Qwen2.5-VL-7B":
            self.processor, self.model = self._load_qwen25_model("Qwen/Qwen2.5-VL-7B-Instruct")
        elif model_name == "InternVL3_5-8B":
            self.tokenizer, self.model = self._load_internvl35_model("OpenGVLab/InternVL3_5-8B-Instruct")
        elif model_name == "Gemma3-4B":
            self.processor, self.model = self._load_gemma3_model("google/gemma-3-4b-it")
        else:
            raise ValueError(f"Model {model_name} is not supported or not available.")
        
        self.current_model_name = model_name
        print(f"Successfully loaded model: {model_name}")
        if model_name == "InternVL3_5-8B":
            return self.tokenizer, self.model
        return self.processor, self.model

    def generate_caption(self, image):
        """
        Generate a caption for the given image using the currently loaded model.
        
        Args:
            image: The image to caption (numpy array of shape (H, W, 3) or PIL.Image).
        
        Returns:
            str: The generated caption.
        """
        if self.current_model_name == "SmolVLM-500M":
            return self._inference_smolvlm_model(image)
        elif self.current_model_name == "Qwen2.5-VL-7B":
            return self._inference_qwen25_model(image)
        elif self.current_model_name == "InternVL3_5-8B":
            return self._inference_internvl35_model(image)
        elif self.current_model_name == "Gemma3-4B":
            return self._inference_gemma3_model(image)
        else:
            raise ValueError(f"Model {self.current_model_name} is not supported or not available.")
    
    def get_current_model(self):
        """
        Get the currently loaded model and processor.
        
        Returns:
            tuple: (processor, model, model_name). For InternVL3_5-8B the
            processor slot is None; use `self.tokenizer` instead.
        """
        return self.processor, self.model, self.current_model_name
    
    def cleanup_memory(self):
        """
        Explicit memory cleanup method that can be called to free GPU memory.
        """
        if self.model is not None:
            del self.model
            self.model = None
        if hasattr(self, 'processor') and self.processor is not None:
            del self.processor
            self.processor = None
        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        
        self.current_model_name = None
        
        # Force cleanup
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        
        print("Memory cleanup completed.")
    
    #########################################################
    ## Load functions

    def _load_smolvlm_model(self, model_name):
        """Load SmolVLM model."""
        processor = AutoProcessor.from_pretrained(model_name)
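        # Note: with no explicit torch_dtype, from_pretrained typically loads
        # weights in float32; passing torch_dtype=torch.bfloat16 (assumption:
        # hardware that supports it) would roughly halve memory use.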
        model = AutoModelForVision2Seq.from_pretrained(
            model_name, 
            _attn_implementation="eager"
        ).to(self.device)
        model.eval()
        return processor, model
    
    def _load_qwen25_model(self, model_name):
        """Load Qwen2.5-VL model."""
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto"
        )

        # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
        # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        #     "Qwen/Qwen2.5-VL-7B-Instruct",
        #     torch_dtype=torch.bfloat16,
        #     attn_implementation="flash_attention_2",
        #     device_map="auto",
        # )

        processor = AutoProcessor.from_pretrained(model_name)
        model.eval()
        return processor, model
    
    def _load_internvl35_model(self, model_name):
        """Load InternVL3.5 model."""
        # Load tokenizer (InternVL uses tokenizer instead of processor for text)
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        
        # Load the model using AutoModel
        model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
            low_cpu_mem_usage=True,
            use_flash_attn=False,                   # set True only if flash-attn is installed and matches your CUDA setup
            trust_remote_code=True,
            device_map="auto"
        )

        model.eval()
        
        # Return tokenizer as processor for consistency with the interface
        return tokenizer, model
    
    def _load_gemma3_model(self, model_name):
        """Load Gemma3 model."""
        quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=quantization_config
        )
        processor = AutoProcessor.from_pretrained(model_name)
        model.eval()
        return processor, model

    #########################################################
    ## Inference functions
    def check_processor_and_model(self):
        """Raise if no processor/model pair is currently loaded."""
        if self.processor is None or self.model is None:
            raise ValueError("Processor and model must be loaded before generating a caption.")
    
    def _inference_qwen25_model(self, image):
        """Inference Qwen2.5-VL model."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": Image.fromarray(image),
                    },
                    {"type": "text", "text": self.user_prompt},
                ],
            }
        ]

        # Preparation for inference
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # Inference: Generation of the output
        generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        caption = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        
        # Clean up tensors to free GPU memory
        del inputs, generated_ids, generated_ids_trimmed
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        return caption
    
    def _inference_gemma3_model(self, image):
        """Inference Gemma3 model."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": Image.fromarray(image)},
                    {"type": "text", "text": self.user_prompt}
                ]
            }
        ]
        
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)

        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
            generation = generation[0][input_len:]

        caption = self.processor.decode(generation, skip_special_tokens=True)
        
        # Clean up tensors to free GPU memory
        del inputs, generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        return caption
    
    def _inference_smolvlm_model(self, image):
        """Inference SmolVLM model."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                # Content must be a list of typed parts for the chat template,
                # matching the format used for the other models.
                "content": [{"type": "text", "text": self.system_message}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.user_prompt}
                ]
            }
        ]

        # Prepare inputs
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
        inputs = inputs.to(self.model.device)

        # Generate outputs
        gen_kwargs = {
            "max_new_tokens": 200,          # plenty for ~40 words
            # "early_stopping": True,         # stop at first EOS (beam search only)
            # "no_repeat_ngram_size": 3,      # discourage loops
            # "length_penalty": 0.8,          # slightly favor brevity
            # "eos_token_id": self.processor.tokenizer.eos_token_id,
            # "pad_token_id": self.processor.tokenizer.eos_token_id,
        }
        generated_ids = self.model.generate(**inputs, **gen_kwargs)
        generated_texts = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]

        # Extract only what the assistant said
        if "Assistant:" in generated_texts:
            caption = generated_texts.split("Assistant:", 1)[1].strip()
        else:
            caption = generated_texts.strip()
        
        # Clean up tensors to free GPU memory
        del inputs, generated_ids
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        return caption

    def _inference_internvl35_model(self, image):
        """Inference InternVL3.5 model."""
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be loaded before generating a caption for InternVL3.5.")
        # image can be numpy (H,W,3) or PIL.Image
        if hasattr(image, "shape"):  # numpy array
            pil_image = Image.fromarray(image.astype("uint8"), mode="RGB")
        else:
            pil_image = image

        pixel_values = self._image_to_pixel_values(pil_image, size=448, max_num=12)
        # Match the dtype chosen at load time (bfloat16 or float16).
        pixel_values = pixel_values.to(dtype=self.model.dtype, device=self.model.device)

        # Format question with image token (matches official docs)
        question = "<image>\n" + self.user_prompt
        
        # Generation config matching official examples
        gen_cfg = dict(
            max_new_tokens=128,
            do_sample=False,
            temperature=0.0,
            # Optional: add other parameters from docs
            # top_p=0.9,
            # repetition_penalty=1.1
        )

        # Use model's chat method (official approach)
        response = self.model.chat(self.tokenizer, pixel_values, question, gen_cfg)
        
        # Clean up tensors to free GPU memory
        del pixel_values
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        return response.strip()

    def _image_to_pixel_values(self, img, size=448, max_num=12):
        """Convert a PIL image into a stacked tensor of normalized size x size tiles."""
        transform = self._build_transform(size)
        tiles = self._dynamic_preprocess(img, image_size=size, max_num=max_num, use_thumbnail=True)
        pixel_values = torch.stack([transform(t) for t in tiles])
        return pixel_values

    
    def _dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
        # same logic as the model card: split into tiles based on aspect ratio
        w, h = image.size
        aspect = w / h
        targets = sorted(
            {(i, j) for n in range(min_num, max_num + 1)
             for i in range(1, n + 1) for j in range(1, n + 1)
             if min_num <= i * j <= max_num},
            key=lambda x: x[0] * x[1]
        )

        # pick the grid whose aspect ratio is closest to the image's
        best = min(targets, key=lambda r: abs(aspect - r[0] / r[1]))
        tw, th = image_size * best[0], image_size * best[1]
        resized = image.resize((tw, th))

        tiles = []
        for i in range(best[0] * best[1]):
            cols = tw // image_size
            box = ((i % cols) * image_size,
                   (i // cols) * image_size,
                   ((i % cols) + 1) * image_size,
                   ((i // cols) + 1) * image_size)
            tiles.append(resized.crop(box))

        if use_thumbnail and len(tiles) != 1:
            tiles.append(image.resize((image_size, image_size)))
        return tiles
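
    # Worked example of the tiling above: an 896x448 image (aspect ratio 2.0)
    # selects the (2, 1) grid, yielding two 448x448 tiles plus a thumbnail tile,
    # so _image_to_pixel_values returns a tensor of shape (3, 3, 448, 448).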

    def _build_transform(self, input_size=448):
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])


# Global VLM Manager instance
vlm_manager = VLMManager()
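

# Minimal usage sketch (not part of the module API). Assumptions: an image file
# at the placeholder path "example.jpg", and enough GPU memory for the chosen
# models. The inference methods accept numpy (H, W, 3) uint8 arrays.
if __name__ == "__main__":
    import numpy as np

    # Load an image and convert it to the numpy layout the inference methods expect.
    img = np.array(Image.open("example.jpg").convert("RGB"))

    # Caption with the default model (Gemma3-4B).
    print(vlm_manager.generate_caption(img))

    # Switch models; the previous model is deleted and CUDA memory cleared first.
    vlm_manager.load_model("Qwen2.5-VL-7B")
    print(vlm_manager.generate_caption(img))

    # Explicitly release GPU memory when done.
    vlm_manager.cleanup_memory()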