# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

import torch
from torch.utils.data._utils.collate import default_collate

from .data_collator import DataCollator


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from .chat_template import ChatTemplate


def split_into_chunks(sequence: Sequence[int], chunk_size: int) -> List[List[int]]:
    """
    Splits a long sequence into consecutive chunks of at most ``chunk_size`` items.

    Args:
        sequence: the sequence to split; it is sliced, so each chunk has the
            same container type as ``sequence``.
        chunk_size: maximum length of each chunk. Must be positive.

    Returns:
        The chunks in order. Only the last chunk may be shorter than
        ``chunk_size``. An empty sequence yields an empty list.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    # A negative step would make ``range`` empty and silently drop every token,
    # and a zero step raises an opaque ``range()`` error; fail loudly instead.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, but got {chunk_size}")
    return [sequence[start : start + chunk_size] for start in range(0, len(sequence), chunk_size)]


def _select_text(example: Dict[str, Any], text_keys: Union[str, List[str]]) -> Any:
    """Return the payload of ``example`` stored under ``text_keys``.

    A string is used as the key directly (a missing key raises ``KeyError``);
    for a list, the first key that is present in ``example`` wins.

    Raises:
        ValueError: if no key of the list is present in ``example``, or if
            ``text_keys`` is neither a string nor a list.
    """
    if isinstance(text_keys, str):
        return example[text_keys]
    if isinstance(text_keys, list):
        for key in text_keys:
            if key in example:
                return example[key]
        raise ValueError(f"None of the keys {text_keys} are found in the example.")
    raise ValueError(f"text_keys must be a string or a list of strings, but got {type(text_keys)}")


def process_pretrain_example(
    example: Dict[str, Any],
    tokenizer: "PreTrainedTokenizer",
    max_seq_len: int,
    text_keys: Union[str, List[str]] = "content_split",
    source_name: Optional[str] = None,
) -> List[Dict[str, "torch.Tensor"]]:
    """
    Tokenizes one pretraining example and splits it into fixed-size chunks.

    The text is encoded without special tokens, ``tokenizer.eos_token_id`` is
    appended, and the resulting ids are cut into chunks of at most
    ``max_seq_len`` tokens.

    Args:
        example: raw example holding the text under one of ``text_keys``.
        tokenizer: tokenizer providing ``encode`` and ``eos_token_id``.
        max_seq_len: maximum number of tokens per returned chunk.
        text_keys: key, or list of candidate keys (first present one is used).
        source_name: currently unused; kept for interface compatibility.

    Returns:
        One dict per chunk with ``input_ids``, ``attention_mask`` (all ones)
        and ``labels`` (an independent copy of ``input_ids``).

    Raises:
        ValueError: if none of ``text_keys`` is found, ``text_keys`` has an
            unsupported type, or ``max_seq_len`` is not positive.
    """
    text_example = _select_text(example, text_keys)
    # NOTE: the EOS id is placed directly into the tensor data, so the
    # tokenizer must define ``eos_token_id`` (``None`` cannot be converted).
    tokens = tokenizer.encode(text_example, add_special_tokens=False) + [tokenizer.eos_token_id]

    examples = []
    for chunk in split_into_chunks(tokens, max_seq_len):
        input_ids = torch.tensor(chunk)
        examples.append(
            {
                "input_ids": input_ids,
                "attention_mask": torch.ones_like(input_ids),
                # Separate storage so in-place edits of labels do not touch input_ids.
                "labels": input_ids.clone(),
            }
        )
    return examples


def process_sft_example(
    example: Dict[str, Any],
    chat_template: "ChatTemplate",
    max_seq_len: int,
    text_keys: Union[str, List[str]] = "messages",
) -> List[Dict[str, "torch.Tensor"]]:
    """
    Encodes one SFT example with the chat template and converts it to tensors.

    Args:
        example: raw example holding the messages under one of ``text_keys``.
        chat_template: object whose ``encode_messages(messages, max_seq_len=...)``
            returns a mapping from feature name to tensor-convertible values.
        max_seq_len: forwarded to ``chat_template.encode_messages``.
        text_keys: key, or list of candidate keys (first present one is used).

    Returns:
        A single-element list with every encoded field wrapped in ``torch.tensor``.

    Raises:
        ValueError: if none of ``text_keys`` is found or ``text_keys`` has an
            unsupported type.
    """
    messages = _select_text(example, text_keys)
    encoded = chat_template.encode_messages(messages, max_seq_len=max_seq_len)
    return [{name: torch.tensor(value) for name, value in encoded.items()}]


@dataclass
class VLADataCollatorWithPacking(DataCollator):
    """
    Data collator that batches per-sample feature dicts for VLA training.

    Features named in ``state_features`` are stacked along a new leading batch
    dimension; every other feature is batched with ``default_collate``.

    Args:
        state_features: names of the features to stack along a new batch
            dimension.
    """

    state_features: List = field(
        default_factory=lambda: [
            "state",
            "images",
            "img_masks",
            "lang_tokens",
            "lang_masks",
            "action_is_pad",
            "actions",
            "joint_mask",
            "label",
            "fast_mask",
        ],
        metadata={"help": "state features with one chunk."},
    )

    def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
        """
        Collates a sequence of samples into a single batch dict.

        Args:
            features: per-sample dicts mapping feature name to value.

        Returns:
            A dict with one batched entry per feature name seen in any sample.
            A feature missing from some samples is batched over only the
            samples that contain it, so its batch size may be smaller.
        """
        # Collect names in first-seen order so the batch layout is deterministic
        # (iterating a set of strings depends on hash randomization).
        input_names = dict.fromkeys(name for feature in features for name in feature)

        batch = {}
        for input_name in input_names:
            values = [feature[input_name] for feature in features if input_name in feature]
            if input_name in self.state_features:
                batch[input_name] = torch.stack(values, dim=0)
            else:
                batch[input_name] = default_collate(values)
        return batch