"""Combined image/text processor for FlexiCT-3D-VLM.""" from __future__ import annotations from pathlib import Path from typing import Any from transformers import AutoTokenizer, BatchFeature, ProcessorMixin from .image_processing_flexict import FlexiCTImageProcessor def _has_local_tokenizer_files(path: str | Path) -> bool: directory = Path(path) return any((directory / name).exists() for name in ("tokenizer.json", "tokenizer_config.json", "vocab.json")) class FlexiCTVLMProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "FlexiCTImageProcessor" tokenizer_class = "AutoTokenizer" def __init__( self, image_processor: FlexiCTImageProcessor, tokenizer, text_model_id: str = "Qwen/Qwen3-Embedding-0.6B", max_length: int = 8192, **kwargs: Any, ): self.text_model_id = text_model_id self.max_length = int(max_length) tokenizer.padding_side = "left" tokenizer.model_max_length = self.max_length self.image_processor = image_processor self.tokenizer = tokenizer self.chat_template = kwargs.pop("chat_template", None) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): trust_remote_code = kwargs.pop("trust_remote_code", True) text_model_id = kwargs.pop("text_model_id", None) max_length = int(kwargs.pop("max_length", 8192)) local_files_only = kwargs.get("local_files_only", False) image_processor = FlexiCTImageProcessor.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs, ) text_model_id = text_model_id or getattr(image_processor, "text_model_id", "Qwen/Qwen3-Embedding-0.6B") tokenizer_source = pretrained_model_name_or_path if _has_local_tokenizer_files(pretrained_model_name_or_path) else text_model_id tokenizer = AutoTokenizer.from_pretrained( tokenizer_source, padding_side="left", model_max_length=max_length, trust_remote_code=trust_remote_code, local_files_only=local_files_only, ) return cls(image_processor=image_processor, tokenizer=tokenizer, text_model_id=text_model_id, max_length=max_length) def __call__( self, images=None, text: str | list[str] | None = None, return_tensors: str | None = "pt", **kwargs: Any, ) -> BatchFeature: data: dict[str, Any] = {} if images is not None: image_kwargs = {k: v for k, v in kwargs.items() if k not in {"padding", "truncation", "max_length"}} data.update(self.image_processor(images, return_tensors=return_tensors, **image_kwargs)) if text is not None: text_batch = self.tokenizer( text, padding=kwargs.pop("padding", True), truncation=kwargs.pop("truncation", True), max_length=int(kwargs.pop("max_length", self.max_length)), return_tensors=return_tensors, ) data.update(text_batch) return BatchFeature(data=data)