import gc

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from transformers import (
    TorchAoConfig,
    Qwen2_5_VLForConditionalGeneration,
    Gemma3ForConditionalGeneration,
    AutoTokenizer,
    AutoProcessor,
    AutoModelForVision2Seq,
    AutoModel,
)
from qwen_vl_utils import process_vision_info

# from transformers.image_utils import load_image

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

class VLMManager:
    """
    A manager class for Vision-Language Models that handles model loading,
    caching, and dynamic switching between different models.
    """

    def __init__(self, default_model: str = "Gemma3-4B"):
        """
        Initialize the VLM Manager with a default model.

        Args:
            default_model (str): The default model to load initially.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.current_model_name = None
        self.processor = None
        self.tokenizer = None  # Initialize tokenizer attribute
        self.model = None
        self.system_message = """
        You are an expert, culturally aware image-analysis assistant. For every image:
        1. Output exactly 40 words in total.
        2. Use a single paragraph (no lists or bullet points).
        3. Describe Who (appearance/emotion), What (action), and Where (setting).
        4. Do NOT include opinions or speculations.
        5. If you go over 40 words, shorten or remove non-essential details.
        """
        self.user_prompt = """
        Given this image, please provide an image description of around 40 words with extensive and detailed visual information.
        Descriptions must be objective: focus on how you would describe the image to someone who can't see it, without your own opinions/speculations.
        The text needs to include the main concept and describe the content of the image in detail by including:
        - Who?: The visual appearance and observable emotions (e.g., "is smiling") of persons and animals.
        - What?: The actions performed in the image.
        - Where?: The setting of the image, including the size, color, and relationships between objects.
        """
        # Load the default model
        self.load_model(default_model)

    def load_model(self, model_name: str):
        """
        Load a VLM model. If the model is already loaded, return the cached version.

        Args:
            model_name (str): The name of the model to load.
        """
        # If the requested model is already loaded, no need to reload
        if self.current_model_name == model_name and self.model is not None:
            print(f"Model {model_name} is already loaded, using cached version.")
            if self.current_model_name == "InternVL3_5-8B":
                return self.tokenizer, self.model
            else:
                return self.processor, self.model

        print(f"Loading model: {model_name}")

        # Clear the current model from memory if one exists
        if self.model is not None:
            del self.model
            self.model = None
        if self.current_model_name == "InternVL3_5-8B":
            if hasattr(self, 'tokenizer') and self.tokenizer is not None:
                del self.tokenizer
                self.tokenizer = None
        else:
            if hasattr(self, 'processor') and self.processor is not None:
                del self.processor
                self.processor = None

        # Force garbage collection and clear the CUDA cache
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()  # Wait for all pending operations to complete

        # Load the new model
        if model_name == "SmolVLM-500M":
            self.processor, self.model = self._load_smolvlm_model("HuggingFaceTB/SmolVLM-500M-Instruct")
        elif model_name == "Qwen2.5-VL-7B":
            self.processor, self.model = self._load_qwen25_model("Qwen/Qwen2.5-VL-7B-Instruct")
        elif model_name == "InternVL3_5-8B":
            self.tokenizer, self.model = self._load_internvl35_model("OpenGVLab/InternVL3_5-8B-Instruct")
        elif model_name == "Gemma3-4B":
            self.processor, self.model = self._load_gemma3_model("google/gemma-3-4b-it")
        else:
            raise ValueError(f"Model {model_name} is not supported or not available.")

        self.current_model_name = model_name
        print(f"Successfully loaded model: {model_name}")

    def generate_caption(self, image):
        """
        Generate a caption for the given image using the currently loaded model.

        Args:
            image: The image to generate a caption for (numpy array or PIL.Image).
        """
        if self.current_model_name == "SmolVLM-500M":
            return self._inference_smolvlm_model(image)
        elif self.current_model_name == "Qwen2.5-VL-7B":
            return self._inference_qwen25_model(image)
        elif self.current_model_name == "InternVL3_5-8B":
            return self._inference_internvl35_model(image)
        elif self.current_model_name == "Gemma3-4B":
            return self._inference_gemma3_model(image)
        else:
            raise ValueError(f"Model {self.current_model_name} is not supported or not available.")

    def get_current_model(self):
        """
        Get the currently loaded model and processor.

        Returns:
            tuple: (processor, model, model_name). Note that for InternVL3.5 the
            processor slot is None and the tokenizer is held in self.tokenizer.
        """
        return self.processor, self.model, self.current_model_name

    def cleanup_memory(self):
        """
        Explicit memory cleanup method that can be called to free GPU memory.
        """
        if self.model is not None:
            del self.model
            self.model = None
        if hasattr(self, 'processor') and self.processor is not None:
            del self.processor
            self.processor = None
        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        self.current_model_name = None
        # Force cleanup
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        print("Memory cleanup completed.")

    #########################################################
    ## Load functions

    def _load_smolvlm_model(self, model_name):
        """Load SmolVLM model."""
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForVision2Seq.from_pretrained(
            model_name,
            _attn_implementation="eager",
        ).to(self.device)
        model.eval()
        return processor, model

    def _load_qwen25_model(self, model_name):
        """Load Qwen2.5-VL model."""
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto"
        )
        # Enabling flash_attention_2 is recommended for better speed and memory
        # savings, especially in multi-image and video scenarios:
        # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        #     "Qwen/Qwen2.5-VL-7B-Instruct",
        #     torch_dtype=torch.bfloat16,
        #     attn_implementation="flash_attention_2",
        #     device_map="auto",
        # )
        processor = AutoProcessor.from_pretrained(model_name)
        model.eval()
        return processor, model

    def _load_internvl35_model(self, model_name):
        """Load InternVL3.5 model."""
        # InternVL uses a tokenizer rather than a processor for text
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Load the model using AutoModel
        model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
            low_cpu_mem_usage=True,
            use_flash_attn=False,  # set to False if flash-attn is unavailable or the CUDA versions mismatch
            trust_remote_code=True,
            device_map="auto",
        )
        model.eval()
        # Return the tokenizer in the processor slot for consistency with the interface
        return tokenizer, model

    def _load_gemma3_model(self, model_name):
        """Load Gemma3 model with int4 weight-only quantization."""
        quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=quantization_config,
        )
        processor = AutoProcessor.from_pretrained(model_name)
        model.eval()
        return processor, model

    #########################################################
    ## Inference functions

    def check_processor_and_model(self):
        if self.processor is None or self.model is None:
            raise ValueError("Processor and model must be loaded before generating a caption.")

    def _inference_qwen25_model(self, image):
        """Run inference with the Qwen2.5-VL model."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": Image.fromarray(image),
                    },
                    {"type": "text", "text": self.user_prompt},
                ],
            },
        ]
        # Preparation for inference
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        # Inference: generation of the output
        generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        caption = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        # Clean up tensors to free GPU memory
        del inputs, generated_ids, generated_ids_trimmed
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return caption

    def _inference_gemma3_model(self, image):
        """Run inference with the Gemma3 model."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": Image.fromarray(image)},
                    {"type": "text", "text": self.user_prompt},
                ],
            },
        ]
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
            generation = generation[0][input_len:]
        caption = self.processor.decode(generation, skip_special_tokens=True)
        # Clean up tensors to free GPU memory
        del inputs, generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return caption

    def _inference_smolvlm_model(self, image):
        """Run inference with the SmolVLM model."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                "content": self.system_message,
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.user_prompt},
                ],
            },
        ]
        # Prepare inputs
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
        inputs = inputs.to(self.model.device)
        # Generate outputs
        gen_kwargs = {
            "max_new_tokens": 200,  # plenty for ~40 words
            # "early_stopping": True,        # stop at first EOS
            # "no_repeat_ngram_size": 3,     # discourage loops
            # "length_penalty": 0.8,         # slightly favor brevity
            # "eos_token_id": self.processor.tokenizer.eos_token_id,
            # "pad_token_id": self.processor.tokenizer.eos_token_id,
        }
        generated_ids = self.model.generate(**inputs, **gen_kwargs)
        generated_texts = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]
        # Extract only what the assistant said
        if "Assistant:" in generated_texts:
            caption = generated_texts.split("Assistant:", 1)[1].strip()
        else:
            caption = generated_texts.strip()
        # Clean up tensors to free GPU memory
        del inputs, generated_ids
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return caption

    def _inference_internvl35_model(self, image):
        """Run inference with the InternVL3.5 model via its chat() interface."""
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be loaded before generating a caption for InternVL3.5.")
        # image can be a numpy array (H, W, 3) or a PIL.Image
        if hasattr(image, "shape"):  # numpy array
            pil_image = Image.fromarray(image.astype("uint8"), mode="RGB")
        else:
            pil_image = image
        pixel_values = self._image_to_pixel_values(pil_image, size=448, max_num=12)
        # Match the model's dtype (bf16 or fp16, depending on what the loader picked)
        pixel_values = pixel_values.to(dtype=self.model.dtype, device=self.model.device)
        # Format the question with the image token (matches the official docs)
        question = "<image>\n" + self.user_prompt
        # Generation config matching the official examples
        gen_cfg = dict(
            max_new_tokens=128,
            do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
            # Optional: add other parameters from the docs
            # top_p=0.9,
            # repetition_penalty=1.1,
        )
        # Use the model's chat method (official approach)
        response = self.model.chat(self.tokenizer, pixel_values, question, gen_cfg)
        # Clean up tensors to free GPU memory
        del pixel_values
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return response.strip()

    def _image_to_pixel_values(self, img, size=448, max_num=12):
        transform = self._build_transform(size)
        tiles = self._dynamic_preprocess(img, image_size=size, max_num=max_num, use_thumbnail=True)
        pixel_values = torch.stack([transform(t) for t in tiles])
        return pixel_values

    def _dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
        # Same logic as the model card: split the image into tiles based on its aspect ratio
        w, h = image.size
        aspect = w / h
        targets = sorted(
            {(i, j) for n in range(min_num, max_num + 1)
                    for i in range(1, n + 1) for j in range(1, n + 1)
                    if min_num <= i * j <= max_num},
            key=lambda x: x[0] * x[1],
        )
        # Pick the grid whose aspect ratio is closest to the image's
        best = min(targets, key=lambda r: abs(aspect - r[0] / r[1]))
        tw, th = image_size * best[0], image_size * best[1]
        resized = image.resize((tw, th))
        tiles = []
        for i in range(best[0] * best[1]):
            box = (
                (i % (tw // image_size)) * image_size,
                (i // (tw // image_size)) * image_size,
                ((i % (tw // image_size)) + 1) * image_size,
                ((i // (tw // image_size)) + 1) * image_size,
            )
            tiles.append(resized.crop(box))
        if use_thumbnail and len(tiles) != 1:
            tiles.append(image.resize((image_size, image_size)))
        return tiles
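
    # A rough sketch of the tiling behavior above (worked examples, not from the
    # original source): a 448x448 image has aspect 1.0, picks the (1, 1) grid, and
    # yields a single tile with no thumbnail; a 1280x720 image (aspect ~1.78) picks
    # the (2, 1) grid, yielding two 448x448 tiles plus the global thumbnail, so
    # _image_to_pixel_values returns a tensor of shape (3, 3, 448, 448).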

    def _build_transform(self, input_size=448):
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])


# Global VLM Manager instance
vlm_manager = VLMManager()
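

# A minimal usage sketch (assumed entry point; the dummy image and the
# `__main__` guard are illustrative and not part of the original Space):
if __name__ == "__main__":
    import numpy as np

    # Hypothetical 640x480 RGB test image; replace with a real photo
    dummy_image = np.zeros((480, 640, 3), dtype=np.uint8)

    # The manager starts with the default model (Gemma3-4B)
    print(vlm_manager.generate_caption(dummy_image))

    # Switching models frees the previous weights before loading the next
    vlm_manager.load_model("SmolVLM-500M")
    print(vlm_manager.generate_caption(dummy_image))

    # Explicitly release GPU memory when done
    vlm_manager.cleanup_memory()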