| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union |
| from dataclasses import dataclass, field |
| from torch.utils.data._utils.collate import default_collate |
| import torch |
| from .data_collator import DataCollator |
|
|
|
|
| if TYPE_CHECKING: |
| from transformers import PreTrainedTokenizer |
|
|
| from .chat_template import ChatTemplate |
|
|
|
|
def split_into_chunks(sequence: Sequence[int], chunk_size: int) -> List[List[int]]:
    """
    Split ``sequence`` into consecutive chunks of at most ``chunk_size`` items.

    Every chunk except possibly the last holds exactly ``chunk_size`` items; the
    last chunk holds the remainder. An empty ``sequence`` yields ``[]``.

    Args:
        sequence: Sliceable sequence of token ids.
        chunk_size: Maximum number of items per chunk; must be positive.

    Returns:
        List of slices of ``sequence`` (each slice has the same type as
        ``sequence``, e.g. lists for a list input).

    Raises:
        ValueError: If ``chunk_size`` is zero or negative.
    """
    if chunk_size <= 0:
        # range() rejects a zero step, but a negative step silently produces no
        # indices, which would drop every item; reject both explicitly.
        raise ValueError(f"chunk_size must be a positive integer, but got {chunk_size}")
    return [sequence[start : start + chunk_size] for start in range(0, len(sequence), chunk_size)]
|
|
|
|
def process_pretrain_example(
    example: Dict[str, Any],
    tokenizer: "PreTrainedTokenizer",
    max_seq_len: int,
    text_keys: Union[str, List[str]] = "content_split",
    source_name: Optional[str] = None,
) -> List[Dict[str, "torch.Tensor"]]:
    """
    Tokenize one raw-text example for pre-training and split it into samples.

    The text is encoded with ``add_special_tokens=False``, then
    ``tokenizer.eos_token_id`` is appended. The resulting ids are split into
    chunks of at most ``max_seq_len`` tokens. Each chunk becomes one sample:
    ``labels`` are a separate copy of ``input_ids``, and ``attention_mask`` is
    all ones.

    Args:
        example: Raw example dict holding the text.
        tokenizer: Tokenizer providing ``encode`` and ``eos_token_id``.
        max_seq_len: Maximum number of tokens per returned sample.
        text_keys: Key of the text field, or a list/tuple of candidate keys tried
            in order (the first one present in ``example`` is used).
        source_name: Not used by this function; kept for interface compatibility.

    Returns:
        One dict per chunk with ``input_ids``, ``attention_mask`` and ``labels``
        tensors.

    Raises:
        KeyError: If ``text_keys`` is a string that is not a key of ``example``.
        ValueError: If none of the candidate keys is present, or ``text_keys``
            is neither a string nor a list/tuple.
    """
    if isinstance(text_keys, str):
        text_example = example[text_keys]
    elif isinstance(text_keys, (list, tuple)):
        for key in text_keys:
            if key in example:
                text_example = example[key]
                break
        else:
            raise ValueError(f"None of the keys {text_keys} are found in the example.")
    else:
        raise ValueError(f"text_keys must be a string or a list/tuple of strings, but got {type(text_keys)}")

    # NOTE(review): if the tokenizer has no EOS token (eos_token_id is None),
    # torch.tensor below will fail on the None entry — confirm all tokenizers used
    # here define one.
    token_ids = tokenizer.encode(text_example, add_special_tokens=False) + [tokenizer.eos_token_id]
    return [
        {
            "input_ids": torch.tensor(chunk),
            "attention_mask": torch.ones(len(chunk), dtype=torch.long),
            # torch.tensor copies its input, so labels do not alias input_ids.
            "labels": torch.tensor(chunk),
        }
        for chunk in split_into_chunks(token_ids, max_seq_len)
    ]
|
|
|
|
def process_sft_example(
    example: Dict[str, Any],
    chat_template: "ChatTemplate",
    max_seq_len: int,
    text_keys: Union[str, List[str]] = "messages",
) -> List[Dict[str, "torch.Tensor"]]:
    """
    Encode one supervised fine-tuning (chat) example with a chat template.

    The conversation is looked up in ``example`` and passed to
    ``chat_template.encode_messages(..., max_seq_len=max_seq_len)``. Every value
    of the returned mapping is converted with ``torch.tensor``.

    Args:
        example: Raw example dict holding the conversation.
        chat_template: Object providing ``encode_messages(messages, max_seq_len=...)``
            that returns a mapping of feature name to tensor-convertible values.
        max_seq_len: Forwarded unchanged to ``encode_messages``.
        text_keys: Key of the conversation field, or a list/tuple of candidate
            keys tried in order (the first one present in ``example`` is used).

    Returns:
        A one-element list holding the tensorized encoding (presumably a list
        for parity with the pre-training helper, which can yield several samples).

    Raises:
        KeyError: If ``text_keys`` is a string that is not a key of ``example``.
        ValueError: If none of the candidate keys is present, or ``text_keys``
            is neither a string nor a list/tuple.
    """
    if isinstance(text_keys, str):
        messages = example[text_keys]
    elif isinstance(text_keys, (list, tuple)):
        for key in text_keys:
            if key in example:
                messages = example[key]
                break
        else:
            raise ValueError(f"None of the keys {text_keys} are found in the example.")
    else:
        raise ValueError(f"text_keys must be a string or a list/tuple of strings, but got {type(text_keys)}")

    encoded = chat_template.encode_messages(messages, max_seq_len=max_seq_len)
    return [{name: torch.tensor(value) for name, value in encoded.items()}]
|
|
|
|
@dataclass
class VLADataCollatorWithPacking(DataCollator):
    """
    Collate a sequence of per-sample feature dicts into one batch dict.

    Features whose name is listed in ``state_features`` are stacked along a new
    leading batch dimension with ``torch.stack`` (so all samples' tensors for
    that feature must share one shape). Every other feature is passed to
    PyTorch's ``default_collate``. A feature missing from some samples is
    collated only from the samples that contain it. As written, this collator
    does not concatenate/pack sequences despite its name.

    Args:
        state_features: Names of features that hold one tensor per sample and
            are stacked into a ``(batch, ...)`` tensor.
    """

    state_features: List = field(
        default_factory=lambda: [
            "state",
            "images",
            "img_masks",
            "lang_tokens",
            "lang_masks",
            "action_is_pad",
            "actions",
            "joint_mask",
            "label",
            "fast_mask",
        ],
        metadata={"help": "state features with one chunk."},
    )

    def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
        # dict.fromkeys keeps first-seen order, so the batch's key order is
        # deterministic; iterating a set of strings varies between runs because
        # of hash randomization.
        feature_names = dict.fromkeys(name for feature in features for name in feature)
        batch = {}
        for name in feature_names:
            values = [feature[name] for feature in features if name in feature]
            if name in self.state_features:
                # Equivalent to torch.cat([v.unsqueeze(0) for v in values], dim=0).
                batch[name] = torch.stack(values, dim=0)
            else:
                batch[name] = default_collate(values)
        return batch