Spaces:
Sleeping
Sleeping
| """ | |
| VLM loader with hardware auto-detection. | |
| Auto-detects CUDA VRAM and loads appropriate model: | |
| - Moondream 2B for <16GB VRAM (Colab T4) | |
| - Qwen2.5-VL-3B for >=16GB VRAM (e.g. RTX 5060 Ti 16 GB) | |
| Addresses critical model loading bugs encountered during development. | |
| """ | |
| import importlib.util | |
| from typing import Optional | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| from ..config import ( | |
| MOONDREAM_ID, | |
| MOONDREAM_REVISION, | |
| QWEN_ID, | |
| QWEN_MAX_PIXELS, | |
| QWEN_MIN_PIXELS, | |
| ) | |
def get_vram_gb() -> float:
    """Report the total memory of CUDA device 0 in GiB.

    Returns:
        Total VRAM of device 0 in gigabytes (units of 1024**3 bytes), or
        0.0 when CUDA is not available on this host.
    """
    if torch.cuda.is_available():
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        return total_bytes / 1024 ** 3
    return 0.0
| def _device_map() -> dict | str: | |
| """Pick a device_map that works on both GPU and CPU hosts.""" | |
| return {"": "cuda"} if torch.cuda.is_available() else {"": "cpu"} | |
| def _has_flash_attention() -> bool: | |
| """Check whether flash-attn is installed.""" | |
| return importlib.util.find_spec("flash_attn") is not None | |
| def _has_qwen_vl_utils() -> bool: | |
| """Check whether qwen_vl_utils is installed.""" | |
| return importlib.util.find_spec("qwen_vl_utils") is not None | |
class VLM:
    """Vision-Language Model loader with automatic hardware-based model selection.

    Loads Moondream 2B when CUDA device 0 reports less than
    ``QWEN_MIN_VRAM_GB`` of memory (e.g. a Colab T4, or a CPU-only host, for
    which ``get_vram_gb`` returns 0.0), and Qwen2.5-VL-3B otherwise (e.g. a
    16 GB card such as the RTX 5060 Ti 16 GB).
    """

    # Auto-selection threshold in GiB, compared against ``get_vram_gb()``.
    # A card sold as "16 GB" reports total memory slightly *below* 16 GiB,
    # so comparing against exactly 16.0 would never select Qwen on such a
    # card. NOTE(review): a Colab T4 is typically reported at ~14.7 GiB and
    # 16 GB desktop cards at ~15.6-16.0 GiB; confirm on the target hardware.
    QWEN_MIN_VRAM_GB: float = 15.5

    # Values accepted for ``force_model`` (matched case-insensitively).
    _MODEL_CHOICES: tuple[str, ...] = ("moondream", "qwen")

    def __init__(self, force_model: Optional[str] = None) -> None:
        """Detect available VRAM and load the appropriate VLM.

        Args:
            force_model: Override auto-selection. Pass ``"moondream"`` or
                ``"qwen"`` (case-insensitive) to skip the VRAM threshold
                check. ``None`` or an empty string means auto-select.

        Raises:
            ValueError: If ``force_model`` names an unsupported model.
                Previously any unknown value silently loaded Moondream.
        """
        # Validate before probing hardware or downloading anything.
        use_model: Optional[str] = None
        if force_model:
            use_model = force_model.strip().lower()
            if use_model not in self._MODEL_CHOICES:
                raise ValueError(
                    f"force_model must be one of {self._MODEL_CHOICES}, "
                    f"got {force_model!r}"
                )

        self.vram_gb: float = get_vram_gb()
        self.model_type: str = ""
        self.model: Optional[torch.nn.Module] = None
        self.processor: Optional[AutoProcessor] = None

        if use_model is None:
            use_model = (
                "qwen" if self.vram_gb >= self.QWEN_MIN_VRAM_GB else "moondream"
            )

        if use_model == "qwen":
            self._load_qwen()
        else:
            self._load_moondream()

    # -- Loaders ---------------------------------------------------------------

    @staticmethod
    def _report_gpu_memory() -> None:
        """Print the CUDA memory currently allocated by tensors (CUDA hosts only)."""
        if torch.cuda.is_available():
            print(
                f" GPU memory allocated: "
                f"{torch.cuda.memory_allocated() / 1024**2:.0f} MB"
            )

    def _load_moondream(self) -> None:
        """Load Moondream 2B (the low-VRAM / CPU choice).

        CRITICAL: device_map must be {"": "cuda"}, NOT "auto" -- "auto" crashes
        on transformers 5.x due to the all_tied_weights_keys bug.
        The revision is pinned (MOONDREAM_REVISION) for API stability.
        """
        print(f"Loading Moondream 2B (VRAM: {self.vram_gb:.1f} GB)...")
        self.model_type = "moondream"
        self.model = AutoModelForCausalLM.from_pretrained(
            MOONDREAM_ID,
            revision=MOONDREAM_REVISION,
            device_map=_device_map(),
            trust_remote_code=True,
        )
        self._report_gpu_memory()

    @staticmethod
    def _apply_pixel_caps(image_processor) -> None:
        """Force QWEN_MIN_PIXELS / QWEN_MAX_PIXELS onto a Qwen image processor.

        CRITICAL: without caps, dynamic resolution can generate 16k+ visual
        tokens and cause OOM on a 16 GB GPU. Some transformers versions read
        the caps from the ``min_pixels``/``max_pixels`` attributes, and others
        read them from ``size["shortest_edge"/"longest_edge"]``. Both are set.
        """
        image_processor.min_pixels = QWEN_MIN_PIXELS
        image_processor.max_pixels = QWEN_MAX_PIXELS
        size = getattr(image_processor, "size", None)
        if isinstance(size, dict) and "shortest_edge" in size and "longest_edge" in size:
            image_processor.size = {
                **size,
                "shortest_edge": QWEN_MIN_PIXELS,
                "longest_edge": QWEN_MAX_PIXELS,
            }

    @staticmethod
    def _qwen_model_class():
        """Return the transformers Auto class used to load Qwen2.5-VL.

        ``AutoModelForVision2Seq`` only exists in transformers < 5. Its
        replacement, ``AutoModelForImageTextToText``, is used as a fallback so
        the loader also works on transformers 5.x.
        """
        try:
            from transformers import AutoModelForVision2Seq
            return AutoModelForVision2Seq
        except ImportError:
            from transformers import AutoModelForImageTextToText
            return AutoModelForImageTextToText

    def _load_qwen(self) -> None:
        """Load Qwen2.5-VL-3B together with a pixel-capped processor.

        CRITICAL: pixel caps MUST reach processor.image_processor (not just the
        processor wrapper). They are passed to ``from_pretrained`` (the route
        documented on the Qwen2.5-VL model card) and then re-applied directly
        to the image processor as a safeguard.
        """
        print(f"Loading Qwen2.5-VL-3B (VRAM: {self.vram_gb:.1f} GB)...")
        self.model_type = "qwen"
        self.processor = AutoProcessor.from_pretrained(
            QWEN_ID,
            min_pixels=QWEN_MIN_PIXELS,
            max_pixels=QWEN_MAX_PIXELS,
            use_fast=True,
        )
        self._apply_pixel_caps(self.processor.image_processor)

        load_kwargs: dict = {
            "device_map": _device_map(),
            # fp16 on GPU; CPU kernels generally expect fp32.
            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
        }
        if torch.cuda.is_available() and _has_flash_attention():
            load_kwargs["attn_implementation"] = "flash_attention_2"

        self.model = self._qwen_model_class().from_pretrained(QWEN_ID, **load_kwargs)
        self._report_gpu_memory()

    # -- Inference -------------------------------------------------------------

    def query_vlm(self, image: Image.Image, question: str) -> str:
        """Query the loaded VLM with an image and a text question.

        Args:
            image: PIL Image to analyse.
            question: Text question or prompt about the image.

        Returns:
            The model's text response (prompt stripped).

        Raises:
            RuntimeError: If no model has been loaded.
        """
        if self.model_type == "moondream":
            return self._query_moondream(image, question)
        if self.model_type == "qwen":
            return self._query_qwen(image, question)
        raise RuntimeError(f"No VLM loaded (model_type={self.model_type!r})")

    def _query_moondream(self, image: Image.Image, question: str) -> str:
        """Query Moondream through its remote-code ``query()`` API.

        Args:
            image: PIL Image.
            question: Text question.

        Returns:
            The ``"answer"`` entry of the dict returned by ``model.query()``.
        """
        with torch.inference_mode():
            return self.model.query(image, question)["answer"]

    def _query_qwen(self, image: Image.Image, question: str) -> str:
        """Query Qwen2.5-VL using the chat-template pipeline.

        Args:
            image: PIL Image.
            question: Text question.

        Returns:
            Generated answer text (input tokens stripped). Decoding is greedy
            (``do_sample=False``) and capped at 256 new tokens.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]
        # tokenize=False so we get a plain string; the processor tokenises.
        text: str = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Use process_vision_info from qwen_vl_utils when available -- it
        # handles the image/video extraction from the message dict.
        if _has_qwen_vl_utils():
            from qwen_vl_utils import process_vision_info  # type: ignore[import]
            image_inputs, video_inputs = process_vision_info(messages)
        else:
            image_inputs = [image]
            video_inputs = None

        processor_kwargs: dict = dict(
            text=[text],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        if video_inputs is not None:
            processor_kwargs["videos"] = video_inputs

        # Follow the model's actual placement instead of hard-coding a device.
        inputs = self.processor(**processor_kwargs).to(self.model.device)

        with torch.inference_mode():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
            )

        # Trim the prompt tokens so batch_decode only sees newly generated ones.
        trimmed = [
            out[len(inp):]
            for inp, out in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]