| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from transformers import ProcessorMixin |
| | from typing import List, Union, Dict, Any, Optional |
| | import torch |
| |
|
| |
|
class XVLAProcessor(ProcessorMixin):
    """
    Unified multimodal processor for XVLA models.

    Wraps an image processor and a tokenizer behind a single callable so that
    multi-view images and language instructions can be prepared together:

    >>> processor = XVLAProcessor.from_pretrained("path/to/xvla")
    >>> inputs = processor(images=batch_images, language_instruction=batch_texts)

    Handles:
      - Multi-view image inputs (e.g., from multiple cameras).
      - Batch processing for multiple samples.
      - Joint tokenization and image tensor preparation.

    Attributes
    ----------
    num_views : int, default=3
        Number of image views per sample in the output. Samples with fewer
        views are padded with all-zero images and flagged in ``image_mask``.
    language_max_length : int, default=50
        Token length every instruction is padded/truncated to.
    attributes : list
        Required by ProcessorMixin to know which submodules are stored and reloaded.
    image_processor_class : str
        The name of the associated image processor class.
    tokenizer_class : tuple(str)
        The names of compatible tokenizer classes.
    """

    num_views: int = 3
    language_max_length: int = 50

    # Sub-processor registry consumed by ProcessorMixin.
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None):
        """
        Initialize XVLAProcessor.

        Parameters
        ----------
        image_processor : PreTrainedImageProcessor, optional
            The image processor used to normalize/resize images.
        tokenizer : PreTrainedTokenizer, optional
            The tokenizer used for text tokenization.
        """
        super().__init__(image_processor, tokenizer)

    def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """
        Tokenize one or more language instructions.

        Every instruction is padded or truncated to ``language_max_length``
        tokens. Only the token ids are returned; any other tokenizer output
        (such as an attention mask) is discarded.

        Parameters
        ----------
        language_instruction : str or List[str]
            A single instruction or a batch of instructions.

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "input_ids": tensor of shape [B, L]
            }
        """
        # Promote a single instruction to a batch of one.
        if isinstance(language_instruction, str):
            language_instruction = [language_instruction]

        inputs = self.tokenizer(
            language_instruction,
            return_tensors="pt",
            padding="max_length",
            max_length=self.language_max_length,
            truncation=True,
        )
        return {"input_ids": inputs["input_ids"]}

    def encode_image(
        self,
        images: Union[List, List[List]],
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Preprocess one or more sets of multi-view images.

        Parameters
        ----------
        images : List or List[List]
            Single sample: [img1, img2, ...]
            Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
            The input is treated as a batch when its first element is a
            list or tuple; otherwise it is treated as one sample.
        kwargs : dict
            Extra arguments passed to the underlying image processor
            (e.g., `do_resize=False`, `size=(224,224)`).

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "image_input": tensor [B, num_views, C, H, W],
                "image_mask": bool tensor [B, num_views], True for real views
            }

        Raises
        ------
        ValueError
            If ``images`` is empty, or if any sample contains more than
            ``num_views`` views.
        """
        # len() rather than truthiness: `images` may be a tensor/array, whose
        # truth value is ambiguous.
        if len(images) == 0:
            raise ValueError("`images` must contain at least one image.")

        # Promote a single sample to a batch of one.
        if not isinstance(images[0], (list, tuple)):
            images = [images]

        batch_imgs, batch_masks = [], []

        for sample_idx, sample_imgs in enumerate(images):
            processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
            num_present = processed.size(0)

            # Too many views would yield a tensor that disagrees with the
            # [num_views] mask (or an opaque stacking error later), so fail early.
            if num_present > self.num_views:
                raise ValueError(
                    f"Sample {sample_idx} has {num_present} views, "
                    f"but at most {self.num_views} are supported."
                )

            # Pad missing views with zero images of the same shape/dtype/device.
            if num_present < self.num_views:
                padding = processed.new_zeros(self.num_views - num_present, *processed.shape[1:])
                processed = torch.cat([processed, padding], dim=0)

            # True for real views, False for zero padding.
            image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
            image_mask[:num_present] = True

            batch_imgs.append(processed)
            batch_masks.append(image_mask)

        return {
            "image_input": torch.stack(batch_imgs, dim=0),
            "image_mask": torch.stack(batch_masks, dim=0),
        }

    def __call__(
        self,
        images: Optional[Union[List, List[List]]] = None,
        language_instruction: Optional[Union[str, List[str]]] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Combine image and text encoding into a unified multimodal input.

        Parameters
        ----------
        images : List or List[List], optional
            Single-sample or batched multi-view images.
        language_instruction : str or List[str], optional
            Corresponding text instructions.
        kwargs : dict
            Extra args passed to image processor.

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "input_ids": [B, L], optional,
                "image_input": [B, num_views, C, H, W], optional,
                "image_mask": [B, num_views], optional
            }
            Keys are present only for the modalities that were supplied;
            the dict is empty when neither argument is given.

        Raises
        ------
        ValueError
            If both modalities are given and their batch sizes differ.
        """
        outputs: Dict[str, Any] = {}

        if language_instruction is not None:
            outputs.update(self.encode_language(language_instruction))

        if images is not None:
            outputs.update(self.encode_image(images, **kwargs))

        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if "input_ids" in outputs and "image_input" in outputs:
            text_batch = outputs["input_ids"].size(0)
            image_batch = outputs["image_input"].size(0)
            if text_batch != image_batch:
                raise ValueError(
                    f"Batch mismatch: text batch {text_batch} "
                    f"!= image batch {image_batch}"
                )
        return outputs
| |
|