File size: 13,030 Bytes
0b7ec27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
"""
Processor class for Phi4-Siglip.

This module provides:
- Phi4VisionRProcessor: Combined tokenizer and image processor
- Utility functions for image and text processing
"""

from typing import List, Optional, Union

import torch
from PIL import Image
from transformers import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
from transformers.utils import TensorType

# Constants (duplicated here to avoid circular imports when running scripts directly)
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"


# =============================================================================
# Image Utilities
# =============================================================================

def process_images(images: List[Image.Image], image_processor, model_cfg=None):
    """
    Run a batch of PIL images through the given image processor.

    Args:
        images: List of PIL images to preprocess.
        image_processor: The image processor; NaFlex-style processors are
            recognised by the presence of a ``max_num_patches`` attribute.
        model_cfg: Unused; retained for API compatibility.

    Returns:
        The full BatchFeature for NaFlex processors, otherwise just the
        ``pixel_values`` entry of the processor output.
    """
    processed = image_processor(images, return_tensors='pt')
    # NaFlex processors return extra fields alongside pixel values, so the
    # whole BatchFeature is handed back instead of a single tensor.
    if hasattr(image_processor, "max_num_patches"):
        return processed
    return processed['pixel_values']


# =============================================================================
# Tokenizer Utilities
# =============================================================================

def tokenizer_image_token(
    prompt: str,
    tokenizer,
    image_token_index: int = IMAGE_TOKEN_INDEX,
    return_tensors: Optional[str] = None
):
    """
    Tokenize a prompt whose <image> placeholders map to image_token_index.

    The prompt is split on the <image> placeholder, each text piece is
    tokenized separately, and exactly one image token id is inserted between
    consecutive pieces.

    Args:
        prompt: Text that may contain <image> placeholders.
        tokenizer: Tokenizer applied to the plain-text pieces.
        image_token_index: Token id substituted for each <image> placeholder.
        return_tensors: If 'pt', the result is returned as a long tensor.

    Returns:
        A list of token ids, or a 1-D torch.LongTensor when
        return_tensors == 'pt'.

    Raises:
        ValueError: For any return_tensors value other than None or 'pt'.
    """
    pieces = [tokenizer(part).input_ids for part in prompt.split(DEFAULT_IMAGE_TOKEN)]

    token_ids = []
    # If the first piece begins with BOS, emit it once up front and strip it
    # from every piece so the joined sequence carries a single BOS.
    bos_offset = 0
    if pieces and pieces[0] and pieces[0][0] == tokenizer.bos_token_id:
        bos_offset = 1
        token_ids.append(pieces[0][0])

    for position, piece in enumerate(pieces):
        if position:
            # Exactly one image token sits between consecutive text pieces.
            token_ids.append(image_token_index)
        token_ids.extend(piece[bos_offset:])

    if return_tensors is None:
        return token_ids
    if return_tensors == 'pt':
        return torch.tensor(token_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')


# =============================================================================
# Main Processor Class
# =============================================================================

class Phi4VisionRProcessor(ProcessorMixin):
    """
    Processor for Phi4-Siglip that wraps an image processor and tokenizer.
    
    This processor handles:
    - Image preprocessing (via SigLIP or SigLIP2/NaFlex)
    - Text tokenization with image token insertion
    - Conversation formatting
    
    Args:
        image_processor: The image processor (from vision tower)
        tokenizer: The text tokenizer
    """
    
    # ProcessorMixin metadata used for serialization / auto-loading.
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"
    
    def __init__(self, image_processor, tokenizer):
        # NOTE(review): attributes are assigned directly instead of calling
        # super().__init__(image_processor, tokenizer); confirm that
        # ProcessorMixin's save/push helpers still behave as expected
        # without the mixin's init-time validation.
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    def __call__(
        self,
        text: Union[TextInput, List[TextInput]] = None,
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        **kwargs,
    ) -> BatchFeature:
        """
        Process text and images for the model.
        
        Args:
            text: The text input(s). Can contain <image> tokens.
            images: The image input(s).
            padding: Padding strategy. Only consulted in the <image>-token
                path when batching more than one sequence; otherwise
                forwarded to the tokenizer.
            truncation: Whether to truncate. NOTE: ignored in the
                <image>-token path; only forwarded for plain tokenization.
            max_length: Maximum sequence length (plain tokenization only).
            return_tensors: Return type for tensors.
            
        Returns:
            BatchFeature with input_ids, attention_mask, and optionally pixel_values.
        """
        # Process images
        if images is not None:
            if not isinstance(images, list):
                images = [images]
            image_inputs = process_images(images, self.image_processor)
        else:
            image_inputs = None
        
        # Process text
        if text is not None:
            if isinstance(text, str):
                text = [text]
            
            # Check if text contains image tokens
            has_images = any(DEFAULT_IMAGE_TOKEN in t for t in text)
            
            if has_images and images is not None:
                # Tokenize with image token handling
                input_ids_list = []
                for t in text:
                    ids = tokenizer_image_token(t, self.tokenizer, return_tensors='pt')
                    input_ids_list.append(ids)
                
                # Pad sequences
                if len(input_ids_list) > 1:
                    max_len = max(len(ids) for ids in input_ids_list)
                    padded_ids = []
                    attention_masks = []
                    # Falls back to 0 when the tokenizer defines no pad token.
                    pad_token_id = self.tokenizer.pad_token_id or 0
                    
                    for ids in input_ids_list:
                        pad_len = max_len - len(ids)
                        if padding and pad_len > 0:
                            # Right-pad ids; mask 1 for real tokens, 0 for padding.
                            padded_ids.append(torch.cat([ids, torch.full((pad_len,), pad_token_id, dtype=torch.long)]))
                            attention_masks.append(torch.cat([torch.ones(len(ids)), torch.zeros(pad_len)]))
                        else:
                            # NOTE(review): with padding falsy and ragged lengths,
                            # torch.stack below will raise; callers appear expected
                            # to pass padding=True for multi-sequence batches — confirm.
                            padded_ids.append(ids)
                            attention_masks.append(torch.ones(len(ids)))
                    
                    input_ids = torch.stack(padded_ids)
                    attention_mask = torch.stack(attention_masks).long()
                else:
                    # Single sequence: add a batch dimension; mask is all ones.
                    input_ids = input_ids_list[0].unsqueeze(0)
                    attention_mask = torch.ones_like(input_ids)
            else:
                # Standard tokenization
                text_inputs = self.tokenizer(
                    text,
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                    return_tensors=return_tensors,
                )
                input_ids = text_inputs["input_ids"]
                attention_mask = text_inputs["attention_mask"]
        else:
            input_ids = None
            attention_mask = None

        # Build output
        data = {}
        if input_ids is not None:
            data["input_ids"] = input_ids
            data["attention_mask"] = attention_mask
        
        if image_inputs is not None:
            if isinstance(image_inputs, BatchFeature):
                # NaFlex case - merge all fields
                data.update(image_inputs)
            else:
                data["pixel_values"] = image_inputs
                
        return BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Decode token ids to text. Forwards to tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode token ids to text. Forwards to tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Get model input names from tokenizer and image processor (deduplicated, order-preserving)."""
        tokenizer_input_names = self.tokenizer.model_input_names
        # Image processors without model_input_names default to pixel_values.
        image_processor_input_names = getattr(
            self.image_processor, 
            'model_input_names', 
            ["pixel_values"]
        )
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load processor from a pretrained model path.
        
        This will load the tokenizer and create the appropriate image processor
        based on the model config.
        """
        from transformers import AutoTokenizer, AutoConfig
        
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        
        # Try to load config to determine vision tower type
        try:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
            vision_tower_name = getattr(config, 'mm_vision_tower', None)
            vision_config = getattr(config, 'vision_config', None)
            
            if vision_tower_name and 'naflex' in vision_tower_name.lower():
                from .modeling_phi4_visionr import Siglip2ImageProcessorNoUpscale
                # Use embedded vision_config to avoid network calls
                # Infer patch_size from model name if not in config (patch14 vs patch16)
                if vision_config is not None:
                    if 'patch_size' in vision_config:
                        patch_size = vision_config['patch_size']
                    elif 'patch14' in vision_tower_name.lower():
                        patch_size = 14
                    else:
                        patch_size = 16  # default for patch16-naflex
                    image_processor = Siglip2ImageProcessorNoUpscale(
                        patch_size=patch_size,
                        max_num_patches=getattr(config, 'max_num_patches', 3600),
                        min_num_patches=getattr(config, 'min_num_patches', 256),
                    )
                else:
                    # Fall back to fetching the processor config from the hub.
                    image_processor = Siglip2ImageProcessorNoUpscale.from_pretrained(
                        vision_tower_name,
                        max_num_patches=getattr(config, 'max_num_patches', 3600),
                        min_num_patches=getattr(config, 'min_num_patches', 256),
                    )
            elif vision_tower_name:
                from transformers import SiglipImageProcessor
                # Use embedded vision_config to avoid network calls
                if vision_config is not None:
                    image_processor = SiglipImageProcessor(
                        size={"height": vision_config.get('image_size', 384), "width": vision_config.get('image_size', 384)},
                    )
                else:
                    image_processor = SiglipImageProcessor.from_pretrained(vision_tower_name)
            else:
                image_processor = None
        except Exception:
            # Best-effort: no image processor if the config is missing or
            # unreadable. NOTE(review): this also swallows genuine errors
            # (e.g. a malformed vision_config) — consider logging before
            # falling back.
            image_processor = None
            
        return cls(image_processor=image_processor, tokenizer=tokenizer)


# =============================================================================
# Convenience Functions
# =============================================================================

def prepare_inputs_for_generation(
    prompt: str,
    images: Optional[List[Image.Image]],
    processor: Phi4VisionRProcessor,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
) -> dict:
    """
    Build model-ready generation inputs for a single-turn user prompt.

    Args:
        prompt: Raw user prompt (no chat formatting applied yet).
        images: Optional list of PIL images to attach to the prompt.
        processor: Processor providing the tokenizer and image preprocessing.
        device: Target device for the produced tensors.
        dtype: Dtype applied to floating-point tensors; integer tensors
            keep their original dtype.

    Returns:
        Mapping of model input names to tensors placed on *device*.
    """
    # Prefix the prompt with the image placeholder when images are supplied.
    user_text = f"{DEFAULT_IMAGE_TOKEN}\n{prompt}" if images else prompt

    # Render a single user turn through the tokenizer's chat template.
    rendered = processor.tokenizer.apply_chat_template(
        [{"role": "user", "content": user_text}],
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = processor(
        text=rendered,
        images=images,
        return_tensors="pt",
    )

    # Move tensors to the target device, casting only floating tensors.
    for name in inputs:
        value = inputs[name]
        if not isinstance(value, torch.Tensor):
            continue
        target_dtype = dtype if value.is_floating_point() else value.dtype
        inputs[name] = value.to(device=device, dtype=target_dtype)

    return inputs