File size: 3,243 Bytes
c119316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Combined image/text processor for FlexiCT-3D-VLM."""

from __future__ import annotations

from pathlib import Path
from typing import Any

from transformers import AutoTokenizer, BatchFeature, ProcessorMixin

from .image_processing_flexict import FlexiCTImageProcessor


def _has_local_tokenizer_files(path: str | Path) -> bool:
    """Report whether *path* contains at least one recognised tokenizer file.

    The files looked for are ``tokenizer.json``, ``tokenizer_config.json`` and
    ``vocab.json``. A path that does not exist on the local filesystem (for
    example a remote repo id) simply yields ``False``.
    """
    root = Path(path)
    for candidate in ("tokenizer.json", "tokenizer_config.json", "vocab.json"):
        if (root / candidate).exists():
            return True
    return False


class FlexiCTVLMProcessor(ProcessorMixin):
    """Pair a ``FlexiCTImageProcessor`` with a Hugging Face tokenizer.

    The tokenizer handed to this processor is configured in place for left
    padding and a ``model_max_length`` equal to ``max_length``.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "FlexiCTImageProcessor"
    tokenizer_class = "AutoTokenizer"

    # Defaults shared by ``__init__`` and ``from_pretrained``.
    _DEFAULT_TEXT_MODEL_ID = "Qwen/Qwen3-Embedding-0.6B"
    _DEFAULT_MAX_LENGTH = 8192
    # Kwargs consumed by the tokenizer in ``__call__``; never sent to the image processor.
    _TEXT_KWARGS = frozenset({"padding", "truncation", "max_length"})
    # Download options that also apply to the tokenizer download in ``from_pretrained``.
    _HUB_KWARGS = ("cache_dir", "token", "force_download", "proxies")

    def __init__(
        self,
        image_processor: FlexiCTImageProcessor,
        tokenizer,
        text_model_id: str = _DEFAULT_TEXT_MODEL_ID,
        max_length: int = _DEFAULT_MAX_LENGTH,
        **kwargs: Any,
    ):
        """Store the sub-processors and configure the tokenizer.

        Args:
            image_processor: Processor applied to ``images`` in ``__call__``.
            tokenizer: Tokenizer applied to ``text`` in ``__call__``. It is
                mutated in place: ``padding_side`` is set to ``"left"`` and
                ``model_max_length`` to ``max_length``.
            text_model_id: Identifier of the text model the tokenizer belongs to.
            max_length: Default maximum token length (coerced to ``int``).
            **kwargs: Only ``chat_template`` is read; other keys are ignored.
        """
        # NOTE(review): ProcessorMixin.__init__ is not called, so its attribute
        # type validation is skipped; presumably intentional for the custom
        # image processor class -- confirm.
        self.text_model_id = text_model_id
        self.max_length = int(max_length)
        tokenizer.padding_side = "left"
        tokenizer.model_max_length = self.max_length
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.chat_template = kwargs.pop("chat_template", None)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the image processor and tokenizer for a FlexiCT checkpoint.

        The tokenizer is read from ``pretrained_model_name_or_path`` when that is
        a local directory containing tokenizer files. Otherwise it is read from
        ``text_model_id``, resolved in this order: the explicit kwarg, then the
        image processor's ``text_model_id`` attribute, then the class default.

        Args:
            pretrained_model_name_or_path: Path or repo id passed to
                ``FlexiCTImageProcessor.from_pretrained``.
            **kwargs: ``trust_remote_code`` (default True), ``text_model_id`` and
                ``max_length`` (None means the default) are consumed here. All
                remaining kwargs go to the image processor. ``local_files_only``,
                ``cache_dir``, ``token``, ``force_download`` and ``proxies`` are
                also forwarded to the tokenizer download.

        Returns:
            A configured ``FlexiCTVLMProcessor``.
        """
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        text_model_id = kwargs.pop("text_model_id", None)
        max_length = kwargs.pop("max_length", None)
        max_length = cls._DEFAULT_MAX_LENGTH if max_length is None else int(max_length)
        local_files_only = kwargs.get("local_files_only", False)

        image_processor = FlexiCTImageProcessor.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        # Guard against the attribute existing but being None.
        text_model_id = (
            text_model_id
            or getattr(image_processor, "text_model_id", None)
            or cls._DEFAULT_TEXT_MODEL_ID
        )
        tokenizer_source = (
            pretrained_model_name_or_path
            if _has_local_tokenizer_files(pretrained_model_name_or_path)
            else text_model_id
        )
        # Without these, e.g. local_files_only=True plus a custom cache_dir could
        # not locate an already-cached tokenizer. ``revision`` is deliberately not
        # forwarded: it pins the checkpoint repo, not ``text_model_id``.
        hub_kwargs = {key: kwargs[key] for key in cls._HUB_KWARGS if key in kwargs}
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_source,
            padding_side="left",
            model_max_length=max_length,
            trust_remote_code=trust_remote_code,
            local_files_only=local_files_only,
            **hub_kwargs,
        )
        return cls(
            image_processor=image_processor,
            tokenizer=tokenizer,
            text_model_id=text_model_id,
            max_length=max_length,
        )

    def __call__(
        self,
        images=None,
        text: str | list[str] | None = None,
        return_tensors: str | None = "pt",
        **kwargs: Any,
    ) -> BatchFeature:
        """Preprocess images and/or text into a single ``BatchFeature``.

        Args:
            images: Passed to the image processor when not None.
            text: String or list of strings tokenized when not None.
            return_tensors: Tensor type forwarded to both sub-processors.
            **kwargs: ``padding`` (default True), ``truncation`` (default True)
                and ``max_length`` (default ``self.max_length``; None also
                selects the default) go to the tokenizer. Every other kwarg goes
                to the image processor, and is ignored when ``images`` is None.

        Returns:
            The merged outputs of the image processor and the tokenizer.
            Tokenizer keys overwrite clashing image keys. The result is empty
            when both ``images`` and ``text`` are None.
        """
        data: dict[str, Any] = {}
        if images is not None:
            image_kwargs = {k: v for k, v in kwargs.items() if k not in self._TEXT_KWARGS}
            data.update(self.image_processor(images, return_tensors=return_tensors, **image_kwargs))
        if text is not None:
            max_length = kwargs.get("max_length")
            text_batch = self.tokenizer(
                text,
                padding=kwargs.get("padding", True),
                truncation=kwargs.get("truncation", True),
                max_length=self.max_length if max_length is None else int(max_length),
                return_tensors=return_tensors,
            )
            data.update(text_batch)
        return BatchFeature(data=data)