import gc

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from transformers import (
    TorchAoConfig,
    Qwen2_5_VLForConditionalGeneration,
    Gemma3ForConditionalGeneration,
    AutoTokenizer,
    AutoProcessor,
    AutoModelForVision2Seq,
    AutoModel,
)
from qwen_vl_utils import process_vision_info
# from transformers.image_utils import load_image

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class VLMManager:
    """
    A manager class for Vision-Language Models that handles model loading,
    caching, and dynamic switching between different models.
    """

    def __init__(self, default_model: str = "Gemma3-4B"):
        """
        Initialize the VLM Manager with a default model.

        Args:
            default_model (str): The default model to load initially.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.current_model_name = None
        self.processor = None
        self.tokenizer = None  # Initialize tokenizer attribute
        self.model = None

        self.system_message = """
You are an expert, culturally aware image-analysis assistant. For every image:
1. Output exactly 40 words in total.
2. Use a single paragraph (no lists or bullet points).
3. Describe Who (appearance/emotion), What (action), and Where (setting).
4. Do NOT include opinions or speculations.
5. If you go over 40 words, shorten or remove non-essential details.
"""

        self.user_prompt = """
Given this image, please provide an image description of around 40 words with extensive and detailed visual information.
Descriptions must be objective: focus on how you would describe the image to someone who can't see it, without your own opinions/speculations.
The text needs to include the main concept and describe the content of the image in detail by including:
- Who?: The visual appearance and observable emotions (e.g., "is smiling") of persons and animals.
- What?: The actions performed in the image.
- Where?: The setting of the image, including the size, color, and relationships between objects.
"""

        # Load the default model
        self.load_model(default_model)

    def load_model(self, model_name: str):
        """
        Load a VLM model. If the model is already loaded, return the cached version.

        Args:
            model_name (str): The name of the model to load.
""" # If the requested model is already loaded, no need to reload if self.current_model_name == model_name and self.model is not None: print(f"Model {model_name} is already loaded, using cached version.") if self.current_model_name == "InternVL3_5-8B": return self.tokenizer, self.model else: return self.processor, self.model print(f"Loading model: {model_name}") # Clear current model from memory if exists if self.model is not None: del self.model self.model = None if self.current_model_name == "InternVL3_5-8B": if hasattr(self, 'tokenizer') and self.tokenizer is not None: del self.tokenizer self.tokenizer = None else: if hasattr(self, 'processor') and self.processor is not None: del self.processor self.processor = None # Force garbage collection and clear CUDA cache gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() # Wait for all operations to complete # Load the new model if model_name == "SmolVLM-500M": self.processor, self.model = self._load_smolvlm_model("HuggingFaceTB/SmolVLM-500M-Instruct") elif model_name == "Qwen2.5-VL-7B": self.processor, self.model = self._load_qwen25_model("Qwen/Qwen2.5-VL-7B-Instruct") elif model_name == "InternVL3_5-8B": self.tokenizer, self.model = self._load_internvl35_model("OpenGVLab/InternVL3_5-8B-Instruct") elif model_name == "Gemma3-4B": self.processor, self.model = self._load_gemma3_model("google/gemma-3-4b-it") else: raise ValueError(f"Model {model_name} is not supported or not available.") self.current_model_name = model_name print(f"Successfully loaded model: {model_name}") def generate_caption(self, image): """ Generate a caption for the given image using the loaded model. Args: processor: The processor for the model. model: The model to use for generating the caption. image: The image to generate a caption for. """ if self.current_model_name == "SmolVLM-500M": return self._inference_smolvlm_model(image) elif self.current_model_name == "Qwen2.5-VL-7B": return self._inference_qwen25_model(image) elif self.current_model_name == "InternVL3_5-8B": return self._inference_internvl35_model(image) elif self.current_model_name == "Gemma3-4B": return self._inference_gemma3_model(image) else: raise ValueError(f"Model {self.current_model_name} is not supported or not available.") def get_current_model(self): """ Get the currently loaded model and processor. Returns: tuple: A tuple containing (processor, model, model_name). """ return self.processor, self.model, self.current_model_name def cleanup_memory(self): """ Explicit memory cleanup method that can be called to free GPU memory. 
""" if self.model is not None: del self.model self.model = None if hasattr(self, 'processor') and self.processor is not None: del self.processor self.processor = None if hasattr(self, 'tokenizer') and self.tokenizer is not None: del self.tokenizer self.tokenizer = None self.current_model_name = None # Force cleanup gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() print("Memory cleanup completed.") ######################################################### ## Load functions def _load_smolvlm_model(self, model_name): """Load SmolVLM model.""" processor = AutoProcessor.from_pretrained(model_name) model = AutoModelForVision2Seq.from_pretrained( model_name, _attn_implementation="eager" ).to(self.device) model.eval() return processor, model def _load_qwen25_model(self, model_name): """Load Qwen2.5-VL model.""" model = Qwen2_5_VLForConditionalGeneration.from_pretrained( model_name, torch_dtype="auto", device_map="auto" ) # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. # model = Qwen2_5_VLForConditionalGeneration.from_pretrained( # "Qwen/Qwen2.5-VL-7B-Instruct", # torch_dtype=torch.bfloat16, # attn_implementation="flash_attention_2", # device_map="auto", # ) processor = AutoProcessor.from_pretrained(model_name) model.eval() return processor, model def _load_internvl35_model(self, model_name): """Load InternVL3.5 model.""" # Load tokenizer (InternVL uses tokenizer instead of processor for text) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) # Load the model using AutoModel model = AutoModel.from_pretrained( model_name, torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16, low_cpu_mem_usage=True, use_flash_attn=False, # True set False if CUDA mismatch trust_remote_code=True, device_map="auto" ) model.eval() # Return tokenizer as processor for consistency with the interface return tokenizer, model def _load_gemma3_model(self, model_name): """Load Gemma3 model.""" quantization_config = TorchAoConfig("int4_weight_only", group_size=128) model = Gemma3ForConditionalGeneration.from_pretrained( model_name, device_map="auto", quantization_config=quantization_config ) processor = AutoProcessor.from_pretrained(model_name) model.eval() return processor, model ######################################################### ## Inference functions def check_processor_and_model(self): if self.processor is None or self.model is None: raise ValueError("Processor and model must be loaded before generating a caption.") def _inference_qwen25_model(self, image): """Inference Qwen2.5-VL model.""" self.check_processor_and_model() messages = [ { "role": "system", "content": [{"type": "text", "text": self.system_message}] }, { "role": "user", "content": [ { "type": "image", "image": Image.fromarray(image), }, {"type": "text", "text": self.user_prompt}, ], } ] # Preparation for inference text = self.processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = self.processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) inputs = inputs.to(self.model.device) # Inference: Generation of the output generated_ids = self.model.generate(**inputs, max_new_tokens=128) generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] caption = 
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # Clean up tensors to free GPU memory
        del inputs, generated_ids, generated_ids_trimmed
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return caption

    def _inference_gemma3_model(self, image):
        """Inference Gemma3 model."""
        self.check_processor_and_model()

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": Image.fromarray(image)},
                    {"type": "text", "text": self.user_prompt}
                ]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)

        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
            generation = generation[0][input_len:]

        caption = self.processor.decode(generation, skip_special_tokens=True)

        # Clean up tensors to free GPU memory
        del inputs, generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return caption

    def _inference_smolvlm_model(self, image):
        """Inference SmolVLM model."""
        self.check_processor_and_model()

        messages = [
            {
                "role": "system",
                "content": self.system_message
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.user_prompt}
                ]
            }
        ]

        # Prepare inputs
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
        inputs = inputs.to(self.model.device)

        # Generate outputs
        gen_kwargs = {
            "max_new_tokens": 200,  # plenty for ~40 words
            # "early_stopping": True,        # stop at first EOS
            # "no_repeat_ngram_size": 3,     # discourage loops
            # "length_penalty": 0.8,         # slightly favor brevity
            # "eos_token_id": self.processor.tokenizer.eos_token_id,
            # "pad_token_id": self.processor.tokenizer.eos_token_id,
        }
        generated_ids = self.model.generate(**inputs, **gen_kwargs)  # max_new_tokens=500
        generated_texts = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]

        # Extract only what the assistant said
        if "Assistant:" in generated_texts:
            caption = generated_texts.split("Assistant:", 1)[1].strip()
        else:
            caption = generated_texts.strip()

        # Clean up tensors to free GPU memory
        del inputs, generated_ids
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return caption

    def _inference_internvl35_model(self, image):
        """Inference InternVL3.5 model."""
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be loaded before generating a caption for InternVL3.5.")

        # image can be numpy (H, W, 3) or PIL.Image
        if hasattr(image, "shape"):  # numpy array
            pil_image = Image.fromarray(image.astype("uint8"), mode="RGB")
        else:
            pil_image = image

        pixel_values = self._image_to_pixel_values(pil_image, size=448, max_num=12)
        pixel_values = pixel_values.to(dtype=torch.bfloat16, device=self.model.device)

        # Format question with the image placeholder token (matches official docs)
        question = "<image>\n" + self.user_prompt

        # Generation config matching official examples
        gen_cfg = dict(
            max_new_tokens=128,
            do_sample=False,
            temperature=0.0,
            # Optional: add other parameters from docs
            # top_p=0.9,
            # repetition_penalty=1.1
        )

        # Use the model's chat method (official approach)
        response = self.model.chat(self.tokenizer, pixel_values, question, gen_cfg)

        # Clean up tensors to free GPU memory
        del pixel_values
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return response.strip()

    def _image_to_pixel_values(self, img, size=448, max_num=12):
        """Convert a PIL image into a stacked tensor of normalized image tiles."""
        transform = self._build_transform(size)
        tiles = self._dynamic_preprocess(img, image_size=size, max_num=max_num, use_thumbnail=True)
        pixel_values = torch.stack([transform(t) for t in tiles])
        return pixel_values

    def _dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
        # same logic as the model card: split into tiles based on aspect ratio
        w, h = image.size
        aspect = w / h
        targets = sorted(
            {(i, j) for n in range(min_num, max_num + 1)
             for i in range(1, n + 1) for j in range(1, n + 1)
             if i * j <= max_num and i * j >= min_num},
            key=lambda x: x[0] * x[1]
        )
        # pick closest ratio
        best = min(targets, key=lambda r: abs(aspect - r[0] / r[1]))
        tw, th = image_size * best[0], image_size * best[1]
        resized = image.resize((tw, th))

        tiles = []
        for i in range(best[0] * best[1]):
            box = (
                (i % (tw // image_size)) * image_size,
                (i // (tw // image_size)) * image_size,
                ((i % (tw // image_size)) + 1) * image_size,
                ((i // (tw // image_size)) + 1) * image_size
            )
            tiles.append(resized.crop(box))

        if use_thumbnail and len(tiles) != 1:
            tiles.append(image.resize((image_size, image_size)))
        return tiles

    def _build_transform(self, input_size=448):
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])


# Global VLM Manager instance
vlm_manager = VLMManager()
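
# --- Usage sketch ---
# A minimal, illustrative example of driving the manager: open an image, caption it
# with the default model, switch models, and free GPU memory. The path "example.jpg"
# is a placeholder assumption; replace it with a real image file. Assumes enough GPU
# memory is available for the selected models.
if __name__ == "__main__":
    import numpy as np

    # Load a test image as an RGB numpy array (the inference helpers convert
    # numpy arrays to PIL images internally where needed).
    img = np.array(Image.open("example.jpg").convert("RGB"))

    # The module-level instance has already loaded the default model (Gemma3-4B).
    print(vlm_manager.generate_caption(img))

    # Switch to another supported model; the previous weights are released first.
    vlm_manager.load_model("Qwen2.5-VL-7B")
    print(vlm_manager.generate_caption(img))

    # Explicitly free GPU memory when done.
    vlm_manager.cleanup_memory()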