# Provenance: copied from a Hugging Face repo (initial commit da3fe02 by gyubin02).
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Iterable, Optional
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
logger = logging.getLogger("labeler")
@dataclass
class ModelConfig:
model_id: str
device: str
precision: str
max_new_tokens: int
load_4bit: bool
class LabelerModel:
def __init__(self, config: ModelConfig) -> None:
self.config = config
self.device = _resolve_device(config.device)
self.dtype = _resolve_dtype(config.precision, self.device)
quantization_config = None
load_kwargs: dict[str, object] = {}
if config.load_4bit:
try:
from transformers import BitsAndBytesConfig
except ImportError as exc:
raise RuntimeError("bitsandbytes is required for 4-bit loading") from exc
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=self.dtype,
)
load_kwargs["quantization_config"] = quantization_config
load_kwargs["device_map"] = "auto"
elif self.device.startswith("cuda"):
load_kwargs["device_map"] = "auto"
self.processor = AutoProcessor.from_pretrained(config.model_id)
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
config.model_id,
torch_dtype=self.dtype,
low_cpu_mem_usage=True,
**load_kwargs,
)
if not load_kwargs.get("device_map"):
self.model.to(self.device)
self.model.eval()
def generate_texts(
self,
messages_list: list[list[dict[str, object]]],
images: Optional[list[object]],
) -> list[str]:
prompts = [
self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
for messages in messages_list
]
if images is None:
inputs = self.processor(
text=prompts,
padding=True,
return_tensors="pt",
)
else:
inputs = self.processor(
text=prompts,
images=images,
padding=True,
return_tensors="pt",
)
inputs = _move_to_device(inputs, self.model.device)
with torch.inference_mode():
output_ids = self.model.generate(
**inputs,
max_new_tokens=self.config.max_new_tokens,
do_sample=False,
)
prompt_length = inputs["input_ids"].shape[1]
generated_ids = output_ids[:, prompt_length:]
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)
def _resolve_device(device: str) -> str:
if device == "auto":
return "cuda" if torch.cuda.is_available() else "cpu"
return device
def _resolve_dtype(precision: str, device: str) -> torch.dtype:
if precision == "fp32":
return torch.float32
if precision == "bf16":
if device.startswith("cuda") and torch.cuda.is_bf16_supported():
return torch.bfloat16
return torch.float16
if precision == "fp16":
return torch.float16
if device.startswith("cuda"):
return torch.float16
return torch.float32
def _move_to_device(inputs: dict[str, object], device: torch.device | str) -> dict[str, object]:
moved = {}
for key, value in inputs.items():
if hasattr(value, "to"):
moved[key] = value.to(device)
else:
moved[key] = value
return moved