"""Qwen2-VL wrapper for the labeler: device/dtype resolution, model loading, batched generation."""
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from typing import Iterable, Optional | |
| import torch | |
| from transformers import AutoProcessor, Qwen2VLForConditionalGeneration | |
| logger = logging.getLogger("labeler") | |
@dataclass
class ModelConfig:
    """Runtime settings for loading and running the labeler model.

    Attributes:
        model_id: Hugging Face model identifier passed to ``from_pretrained``.
        device: ``"auto"``, ``"cpu"``, or a CUDA device string such as ``"cuda:0"``.
        precision: ``"fp32"``, ``"fp16"``, or ``"bf16"``; any other value falls
            back to a per-device default (see ``_resolve_dtype``).
        max_new_tokens: Generation budget forwarded to ``model.generate``.
        load_4bit: When True, load weights 4-bit quantized via bitsandbytes.
    """

    # BUG FIX: the class had only annotations but no @dataclass decorator
    # (despite `dataclass` being imported), so ModelConfig(...) could not be
    # constructed with field values. The decorator generates __init__ et al.
    model_id: str
    device: str
    precision: str
    max_new_tokens: int
    load_4bit: bool
class LabelerModel:
    """Wraps a Qwen2-VL checkpoint plus its processor for batched greedy generation."""

    def __init__(self, config: ModelConfig) -> None:
        """Load the processor and model according to *config*.

        Resolves the target device and dtype, optionally configures 4-bit
        quantization (requires bitsandbytes), and puts the model in eval mode.

        Raises:
            RuntimeError: if ``config.load_4bit`` is set but the installed
                transformers has no ``BitsAndBytesConfig`` (bitsandbytes path).
        """
        self.config = config
        self.device = _resolve_device(config.device)
        self.dtype = _resolve_dtype(config.precision, self.device)
        quantization_config = None
        load_kwargs: dict[str, object] = {}
        if config.load_4bit:
            # bitsandbytes support is optional; surface a clear error instead of
            # a bare ImportError from deep inside transformers.
            try:
                from transformers import BitsAndBytesConfig
            except ImportError as exc:
                raise RuntimeError("bitsandbytes is required for 4-bit loading") from exc
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=self.dtype,
            )
            load_kwargs["quantization_config"] = quantization_config
            # Quantized weights must be placed by accelerate's auto dispatcher.
            load_kwargs["device_map"] = "auto"
        elif self.device.startswith("cuda"):
            load_kwargs["device_map"] = "auto"
        self.processor = AutoProcessor.from_pretrained(config.model_id)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            config.model_id,
            torch_dtype=self.dtype,
            low_cpu_mem_usage=True,
            **load_kwargs,
        )
        # Move manually only when device_map was NOT used — otherwise accelerate
        # has already placed (and possibly sharded) the weights.
        if not load_kwargs.get("device_map"):
            self.model.to(self.device)
        self.model.eval()

    def generate_texts(
        self,
        messages_list: list[list[dict[str, object]]],
        images: Optional[list[object]],
    ) -> list[str]:
        """Greedy-decode one completion per conversation.

        Args:
            messages_list: One chat-message list per sample, in the format
                expected by ``processor.apply_chat_template``.
            images: Image batch passed alongside the prompts, or ``None`` for
                text-only input. NOTE(review): presumably ordered to match
                ``messages_list`` — confirm against callers.

        Returns:
            Decoded completions with the prompt tokens stripped.
        """
        prompts = [
            self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            for messages in messages_list
        ]
        if images is None:
            inputs = self.processor(
                text=prompts,
                padding=True,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=prompts,
                images=images,
                padding=True,
                return_tensors="pt",
            )
        inputs = _move_to_device(inputs, self.model.device)
        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=False,  # greedy decoding for deterministic labels
            )
        # generate() echoes the (padded) prompt; slice it off so only newly
        # generated tokens are decoded.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = output_ids[:, prompt_length:]
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)
| def _resolve_device(device: str) -> str: | |
| if device == "auto": | |
| return "cuda" if torch.cuda.is_available() else "cpu" | |
| return device | |
| def _resolve_dtype(precision: str, device: str) -> torch.dtype: | |
| if precision == "fp32": | |
| return torch.float32 | |
| if precision == "bf16": | |
| if device.startswith("cuda") and torch.cuda.is_bf16_supported(): | |
| return torch.bfloat16 | |
| return torch.float16 | |
| if precision == "fp16": | |
| return torch.float16 | |
| if device.startswith("cuda"): | |
| return torch.float16 | |
| return torch.float32 | |
| def _move_to_device(inputs: dict[str, object], device: torch.device | str) -> dict[str, object]: | |
| moved = {} | |
| for key, value in inputs.items(): | |
| if hasattr(value, "to"): | |
| moved[key] = value.to(device) | |
| else: | |
| moved[key] = value | |
| return moved | |