# lingbot-vla/lingbotvla/data/data_transform.py
# bazaar-research's picture
# Upload folder using huggingface_hub
# fb11af9 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from dataclasses import dataclass, field
from torch.utils.data._utils.collate import default_collate
import torch
from .data_collator import DataCollator
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from .chat_template import ChatTemplate
def split_into_chunks(sequence: Sequence[int], chunk_size: int) -> List[List[int]]:
    """
    Cut ``sequence`` into consecutive slices of at most ``chunk_size`` items.

    For a positive ``chunk_size`` every slice holds exactly ``chunk_size`` items,
    except possibly the last one, which holds whatever remains. An empty
    ``sequence`` produces an empty list.
    """
    # Start offset of every slice: 0, chunk_size, 2 * chunk_size, ...
    starts = range(0, len(sequence), chunk_size)
    return [sequence[start : start + chunk_size] for start in starts]
def process_pretrain_example(
    example: Dict[str, Any],
    tokenizer: "PreTrainedTokenizer",
    max_seq_len: int,
    text_keys: Union[str, List[str]] = "content_split",
    source_name: Optional[str] = None,
) -> List[Dict[str, "torch.Tensor"]]:
    """
    Tokenize one raw pretraining record and cut it into training samples.

    The text is read from ``example`` under ``text_keys``, encoded without special
    tokens, and terminated with the tokenizer's EOS id. The resulting ids are then
    split into pieces of at most ``max_seq_len`` tokens.

    Args:
        example: raw record holding the text.
        tokenizer: object providing ``encode`` and ``eos_token_id``.
        max_seq_len: maximum number of tokens per returned sample.
        text_keys: a single key, or an ordered list of candidate keys of which the
            first one present in ``example`` is used.
        source_name: accepted but not used by this function.

    Returns:
        One dict per piece with ``input_ids``, ``attention_mask`` (all ones) and
        ``labels`` (same values as ``input_ids``, stored in a separate tensor).

    Raises:
        KeyError: if ``text_keys`` is a string that is not a key of ``example``.
        ValueError: if ``text_keys`` is a list and none of its entries is in
            ``example``, or if it is neither a string nor a list.
    """
    if isinstance(text_keys, str):
        text = example[text_keys]
    elif not isinstance(text_keys, list):
        raise ValueError(f"text_keys must be a string or a list of strings, but got {type(text_keys)}")
    else:
        # A private sentinel distinguishes "no key matched" from a stored None value.
        not_found = object()
        text = next((example[name] for name in text_keys if name in example), not_found)
        if text is not_found:
            raise ValueError(f"None of the keys {text_keys} are found in the example.")

    token_ids = tokenizer.encode(text, add_special_tokens=False) + [tokenizer.eos_token_id]

    return [
        {
            "input_ids": torch.tensor(piece),
            "attention_mask": torch.tensor([1] * len(piece)),
            "labels": torch.tensor(piece),
        }
        for piece in split_into_chunks(token_ids, max_seq_len)
    ]
def process_sft_example(
    example: Dict[str, Any],
    chat_template: "ChatTemplate",
    max_seq_len: int,
    text_keys: Union[str, List[str]] = "messages",
) -> List[Dict[str, "torch.Tensor"]]:
    """
    Encode the conversation of one SFT record with ``chat_template``.

    Args:
        example: raw record holding the conversation.
        chat_template: object whose ``encode_messages(messages, max_seq_len=...)``
            returns a mapping from feature name to values accepted by
            ``torch.tensor``.
        max_seq_len: forwarded unchanged to ``encode_messages``.
        text_keys: a single key, or an ordered list of candidate keys of which the
            first one present in ``example`` is used.

    Returns:
        A one-element list holding the encoded features, each converted to a tensor.

    Raises:
        KeyError: if ``text_keys`` is a string that is not a key of ``example``.
        ValueError: if ``text_keys`` is a list and none of its entries is in
            ``example``, or if it is neither a string nor a list.
    """
    if isinstance(text_keys, str):
        messages = example[text_keys]
    elif not isinstance(text_keys, list):
        raise ValueError(f"text_keys must be a string or a list of strings, but got {type(text_keys)}")
    else:
        # A private sentinel distinguishes "no key matched" from a stored None value.
        not_found = object()
        messages = next((example[name] for name in text_keys if name in example), not_found)
        if messages is not_found:
            raise ValueError(f"None of the keys {text_keys} are found in the example.")

    encoded = chat_template.encode_messages(messages, max_seq_len=max_seq_len)

    as_tensors = {}
    for name, values in encoded.items():
        as_tensors[name] = torch.tensor(values)
    return [as_tensors]
@dataclass
class VLADataCollatorWithPacking(DataCollator):
    """
    Collate per-sample feature dicts into a single batch dict.

    Features whose name is listed in ``state_features`` are batched by giving each
    sample's tensor a new leading dimension and concatenating along it. Every other
    feature is passed to ``default_collate``.

    Args:
        state_features: names of the features that are batched by concatenation
            along a new leading dimension.

    Instances are callable: ``collator(features)`` takes a sequence of per-sample
    dicts and returns the batch dict.
    """

    state_features: List = field(
        default_factory=lambda: [
            "state",
            "images",
            "img_masks",
            "lang_tokens",
            "lang_masks",
            "action_is_pad",
            "actions",
            "joint_mask",
            "label",
            "fast_mask",
        ],
        metadata={"help": "state features with one chunk."},
    )

    def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
        """
        Build the batch dict for ``features``.

        A feature name is included if at least one sample has it; samples lacking
        that name are simply left out of that feature's batch entry.
        """
        # Union of the feature names over all samples.
        names = set()
        for sample in features:
            names.update(sample.keys())

        collated = {}
        for name in names:
            # Only the samples that actually carry this feature contribute to it.
            present = [sample[name] for sample in features if name in sample]
            if name in self.state_features:
                # Add a leading batch axis to each tensor, then join along it.
                collated[name] = torch.cat([value.unsqueeze(0) for value in present], dim=0)
            else:
                collated[name] = default_collate(present)
        return collated