# File: lingbot-vla/lingbotvla/data/vla_data/base_dataset.py
# Uploaded by bazaar-research: "Upload folder using huggingface_hub" (commit fb11af9, verified).
# Copyright 2026 Robbyant Team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Dict, List, Literal, Optional
import numpy as np
import torch
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import Dataset, IterableDataset
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from torchvision.transforms.v2 import Resize
from transformers import AutoTokenizer, AutoImageProcessor
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
import json
import yaml
from PIL import Image
from .transform import Normalizer, prepare_action, prepare_images, prepare_language, prepare_state
from ...utils import logging
class VlaDataset(Dataset):
    """Map-style dataset returning raw ``LeRobotDataset`` items with an action chunk.

    Each item is whatever the wrapped ``LeRobotDataset`` yields for the index;
    no normalization or tokenization is applied in this class. For the feature
    named ``action_name`` the wrapped dataset is asked for ``ACTION_CHUNK_SIZE``
    relative timestamps spaced ``1 / fps`` apart, starting at 0.

    ``__getitem__`` is best-effort: if loading an index fails, a random other
    index is tried instead, up to ``MAX_RETRIES`` attempts.
    """

    # Number of action timestamps requested per sample.
    ACTION_CHUNK_SIZE = 50
    # Upper bound on load attempts made by ``__getitem__`` before it gives up.
    MAX_RETRIES = 200

    def __init__(
        self,
        repo_id="path2dataset",
        config=PI0Config,
        tokenizer=AutoTokenizer,
        data_config=None,
        image_processor=None,
        use_depth_align=False,
        action_name="action",
    ):
        """Open the dataset metadata and the dataset itself.

        Args:
            repo_id: Identifier/path handed to ``LeRobotDatasetMetadata`` and
                ``LeRobotDataset``.
            config: Policy configuration; stored on the instance only.
            tokenizer: Tokenizer; stored on the instance only.
            data_config: Accepted for signature parity with the sibling
                dataset classes; not used by this class.
            image_processor: Image processor; stored on the instance only.
            use_depth_align: Accepted for signature parity; not used here.
            action_name: Name of the feature for which an action chunk is
                requested through ``delta_timestamps``.
        """
        self.image_processor = image_processor
        self.config = config
        self.tokenizer = tokenizer
        self.dataset_meta = LeRobotDatasetMetadata(repo_id)
        # Relative timestamps of the action chunk, derived from the dataset's
        # own frame rate. Only the action feature is listed here.
        delta_timestamps = {
            action_name: [
                t / self.dataset_meta.fps for t in range(self.ACTION_CHUNK_SIZE)
            ],
        }
        self.dataset = LeRobotDataset(
            repo_id=repo_id,
            delta_timestamps=delta_timestamps,
        )
        self.action_name = action_name

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def getdata(self, idx):
        """Load item ``idx`` and check that its task text matches the metadata.

        Raises:
            AssertionError: If ``item['task']`` differs from the task the
                metadata lists for ``item['task_index']``.
        """
        item = self.dataset[idx]
        task = self.dataset_meta.tasks[int(item['task_index'])]
        assert task == item['task']
        return item

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return item ``idx``, substituting a random item if loading fails.

        Raises:
            IndexError: If ``idx`` is negative or not smaller than ``len(self)``.
            RuntimeError: If every attempt failed; the last underlying error is
                attached as ``__cause__``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Index {idx} out of bounds.")

        attempts = 0
        cur = idx
        last_err = None
        while attempts < self.MAX_RETRIES:
            try:
                return self.getdata(cur)
            except Exception as e:
                last_err = e
                attempts += 1
                # randint's upper bound is exclusive, so the draw is always a
                # valid index and needs no clamping.
                cur = np.random.randint(0, len(self))

        raise RuntimeError(
            f"Failed to fetch a valid item starting from idx={idx} after {attempts} attempts. "
            f"Last error: {repr(last_err)}"
        ) from last_err
class liberoDataset(Dataset):
    """Map-style dataset producing model-ready samples from a LIBERO LeRobot dataset.

    Each sample is normalized with statistics read from
    ``data_config.norm_stats_file`` and then passed through the ``prepare_*``
    helpers. Unlike the sibling datasets in this module, ``__getitem__`` does
    not retry on failure: any error raised while loading propagates.
    """

    # Number of action timestamps requested per sample.
    ACTION_CHUNK_SIZE = 50

    def __init__(
        self,
        repo_id="libero",
        config=PI0Config,
        tokenizer=AutoTokenizer,
        data_config=None,
        image_processor=None,
        use_depth_align=False,
    ):
        """Open the dataset and build the normalizer.

        Args:
            repo_id: Identifier/path handed to ``LeRobotDatasetMetadata`` and
                ``LeRobotDataset``.
            config: Policy configuration forwarded to the ``prepare_*`` helpers.
            tokenizer: Tokenizer forwarded to ``prepare_language``.
            data_config: Required despite the ``None`` default; its
                ``img_size``, ``norm_stats_file`` and ``norm_type`` attributes
                are read.
            image_processor: Forwarded to ``prepare_images``.
            use_depth_align: Forwarded to ``prepare_images``; when true the
                sample additionally carries ``'pil_images'``.
        """
        image_transforms = Resize((data_config.img_size, data_config.img_size))
        self.image_processor = image_processor
        self.config = config
        self.tokenizer = tokenizer
        self.norm_stats_file = data_config.norm_stats_file
        self.dataset_meta = LeRobotDatasetMetadata(repo_id)
        # Relative timestamps of the action chunk, derived from the dataset's
        # own frame rate. Only the action feature is listed here.
        delta_timestamps = {
            "actions": [
                t / self.dataset_meta.fps for t in range(self.ACTION_CHUNK_SIZE)
            ],
        }
        self.dataset = LeRobotDataset(
            repo_id=repo_id,
            image_transforms=image_transforms,
            delta_timestamps=delta_timestamps,
        )
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.norm_stats_file, encoding="utf-8") as f:
            self.norm_stats = json.load(f)
        self.normalizer = Normalizer(
            norm_stats=self.norm_stats['norm_stats'],
            from_file=True,
            data_type='libero',
            norm_type={
                "image": "identity",
                "wrist_image": "identity",
                "state": data_config.norm_type,
                "actions": data_config.norm_type,
            },
        )
        self.use_depth_align = use_depth_align

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load, normalize and pre-process item ``idx``.

        Returns:
            A dict with keys ``'images'``, ``'img_masks'``, ``'state'``,
            ``'lang_tokens'``, ``'lang_masks'``, ``'actions'`` and
            ``'action_is_pad'``, plus ``'pil_images'`` when
            ``use_depth_align`` is set.

        Raises:
            AssertionError: If ``item['task']`` differs from the task the
                metadata lists for ``item['task_index']``.
        """
        item = self.dataset[idx]
        task = self.dataset_meta.tasks[int(item['task_index'])]
        assert task == item['task']

        normalized_item = self.normalizer.normalize(item)

        # Scale to 0..255 and truncate to uint8.
        # NOTE(review): presumably the images arrive as floats in [0, 1] -- confirm.
        base_image = (normalized_item["image"] * 255).to(torch.uint8)
        wrist_image = (normalized_item["wrist_image"] * 255).to(torch.uint8)

        batch_dict = {
            "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": wrist_image},
            "state": normalized_item["state"].to(torch.float32),
            "action": normalized_item["actions"].to(torch.float32),
            "action_is_pad": normalized_item["actions_is_pad"],
            "prompt": [item["task"]],
        }

        # Shape notes carried over from the original author (unverified here):
        # state bs,8 -> bs,32; language bs,seq_len; actions bs,50,7 -> bs,50,32.
        state = prepare_state(self.config, batch_dict)
        lang_tokens, lang_masks = prepare_language(self.config, self.tokenizer, batch_dict)
        actions = prepare_action(self.config, batch_dict)
        images, img_masks, pil_images = prepare_images(
            self.config,
            self.image_processor,
            batch_dict,
            use_depth_align=self.use_depth_align,
        )

        sample = {
            'images': images,
            'img_masks': img_masks,
            'state': state,
            'lang_tokens': lang_tokens,
            'lang_masks': lang_masks,
            'actions': actions,
            'action_is_pad': batch_dict['action_is_pad'],
        }
        if self.use_depth_align:
            sample['pil_images'] = pil_images
        return sample
class RobotwinDataset(Dataset):
    """Map-style dataset producing model-ready samples from a RoboTwin LeRobot dataset.

    Each sample combines three camera views (``cam_high``, ``cam_left_wrist``,
    ``cam_right_wrist``), the state and an action chunk. Samples are normalized
    with statistics read from ``data_config.norm_stats_file`` and then passed
    through the ``prepare_*`` helpers.

    ``__getitem__`` is best-effort: if loading an index fails, a random other
    index is tried instead, up to ``MAX_RETRIES`` attempts.
    """

    # Number of action timestamps requested per sample.
    ACTION_CHUNK_SIZE = 50
    # Upper bound on load attempts made by ``__getitem__`` before it gives up.
    MAX_RETRIES = 200

    def __init__(
        self,
        repo_id="robotwin",
        config=PI0Config,
        tokenizer=AutoTokenizer,
        data_config=None,
        image_processor=None,
        use_depth_align=False,
    ):
        """Open the dataset and build the normalizer.

        Args:
            repo_id: Identifier/path handed to ``LeRobotDatasetMetadata`` and
                ``LeRobotDataset``.
            config: Policy configuration forwarded to the ``prepare_*`` helpers.
            tokenizer: Tokenizer forwarded to ``prepare_language``.
            data_config: Required despite the ``None`` default; its
                ``img_size``, ``norm_stats_file`` and ``norm_type`` attributes
                are read.
            image_processor: Forwarded to ``prepare_images``.
            use_depth_align: Forwarded to ``prepare_images``; when true the
                sample additionally carries ``'pil_images'``.
        """
        image_transforms = Resize((data_config.img_size, data_config.img_size))
        self.image_processor = image_processor
        self.config = config
        self.tokenizer = tokenizer
        self.norm_stats_file = data_config.norm_stats_file
        self.dataset_meta = LeRobotDatasetMetadata(repo_id)
        # Relative timestamps of the action chunk, derived from the dataset's
        # own frame rate. Only the action feature is listed here.
        delta_timestamps = {
            "action": [
                t / self.dataset_meta.fps for t in range(self.ACTION_CHUNK_SIZE)
            ],
        }
        self.dataset = LeRobotDataset(
            repo_id=repo_id,
            image_transforms=image_transforms,
            delta_timestamps=delta_timestamps,
        )
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.norm_stats_file, encoding="utf-8") as f:
            self.norm_stats = json.load(f)
        self.normalizer = Normalizer(
            norm_stats=self.norm_stats['norm_stats'],
            from_file=True,
            data_type='robotwin',
            norm_type={
                "observation.images.cam_high": "identity",
                "observation.images.cam_left_wrist": "identity",
                "observation.images.cam_right_wrist": "identity",
                "observation.state": data_config.norm_type,
                "action": data_config.norm_type,
            },
        )
        self.use_depth_align = use_depth_align

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def getdata(self, idx):
        """Load, normalize and pre-process item ``idx`` (no retry).

        Returns:
            A dict with keys ``'images'``, ``'img_masks'``, ``'state'``,
            ``'lang_tokens'``, ``'lang_masks'``, ``'actions'`` and
            ``'action_is_pad'``, plus ``'pil_images'`` when
            ``use_depth_align`` is set.

        Raises:
            AssertionError: If ``item['task']`` differs from the task the
                metadata lists for ``item['task_index']``.
        """
        item = self.dataset[idx]
        task = self.dataset_meta.tasks[int(item['task_index'])]
        assert task == item['task']

        normalized_item = self.normalizer.normalize(item)

        # Scale to 0..255 and truncate to uint8.
        # NOTE(review): presumably the images arrive as floats in [0, 1] -- confirm.
        base_image = (normalized_item["observation.images.cam_high"] * 255).to(torch.uint8)
        left_wrist_image = (
            normalized_item["observation.images.cam_left_wrist"] * 255
        ).to(torch.uint8)
        right_wrist_image = (
            normalized_item["observation.images.cam_right_wrist"] * 255
        ).to(torch.uint8)

        batch_dict = {
            "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
            "state": normalized_item["observation.state"].to(torch.float32),
            "action": normalized_item["action"].to(torch.float32),
            "action_is_pad": normalized_item["action_is_pad"],
            "prompt": [item["task"]],
        }

        # Shape notes carried over from the original author (unverified here):
        # state bs,8 -> bs,32; language bs,seq_len; actions bs,50,7 -> bs,50,32.
        state = prepare_state(self.config, batch_dict)
        lang_tokens, lang_masks = prepare_language(self.config, self.tokenizer, batch_dict)
        actions = prepare_action(self.config, batch_dict)
        images, img_masks, pil_images = prepare_images(
            self.config,
            self.image_processor,
            batch_dict,
            use_depth_align=self.use_depth_align,
        )

        sample = {
            'images': images,
            'img_masks': img_masks,
            'state': state,
            'lang_tokens': lang_tokens,
            'lang_masks': lang_masks,
            'actions': actions,
            'action_is_pad': batch_dict['action_is_pad'],
        }
        if self.use_depth_align:
            sample['pil_images'] = pil_images
        return sample

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return the sample for ``idx``, substituting a random one if loading fails.

        Raises:
            IndexError: If ``idx`` is negative or not smaller than ``len(self)``.
            RuntimeError: If every attempt failed; the last underlying error is
                attached as ``__cause__``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Index {idx} out of bounds.")

        attempts = 0
        cur = idx
        last_err = None
        while attempts < self.MAX_RETRIES:
            try:
                return self.getdata(cur)
            except Exception as e:
                last_err = e
                attempts += 1
                # randint's upper bound is exclusive, so the draw is always a
                # valid index and needs no clamping.
                cur = np.random.randint(0, len(self))

        raise RuntimeError(
            f"Failed to fetch a valid item starting from idx={idx} after {attempts} attempts. "
            f"Last error: {repr(last_err)}"
        ) from last_err
class CustomizedRobotwinDataset(Dataset):
    """Map-style dataset for a customized RoboTwin-layout LeRobot dataset.

    Reads the same feature keys as ``RobotwinDataset`` (three camera views,
    ``observation.state`` and ``action``) but builds its ``Normalizer`` with
    ``data_type='customized'``. Samples are normalized with statistics read
    from ``data_config.norm_stats_file`` and then passed through the
    ``prepare_*`` helpers.

    ``__getitem__`` is best-effort: if loading an index fails, a random other
    index is tried instead, up to ``MAX_RETRIES`` attempts.
    """

    # Number of action timestamps requested per sample.
    ACTION_CHUNK_SIZE = 50
    # Upper bound on load attempts made by ``__getitem__`` before it gives up.
    MAX_RETRIES = 200

    def __init__(
        self,
        repo_id="robotwin",
        config=PI0Config,
        tokenizer=AutoTokenizer,
        data_config=None,
        image_processor=None,
        use_depth_align=False,
    ):
        """Open the dataset and build the normalizer.

        Args:
            repo_id: Identifier/path handed to ``LeRobotDatasetMetadata`` and
                ``LeRobotDataset``.
            config: Policy configuration forwarded to the ``prepare_*`` helpers.
            tokenizer: Tokenizer forwarded to ``prepare_language``.
            data_config: Required despite the ``None`` default; its
                ``img_size``, ``norm_stats_file`` and ``norm_type`` attributes
                are read.
            image_processor: Forwarded to ``prepare_images``.
            use_depth_align: Forwarded to ``prepare_images``; when true the
                sample additionally carries ``'pil_images'``.
        """
        image_transforms = Resize((data_config.img_size, data_config.img_size))
        self.image_processor = image_processor
        self.config = config
        self.tokenizer = tokenizer
        self.norm_stats_file = data_config.norm_stats_file
        self.dataset_meta = LeRobotDatasetMetadata(repo_id)
        # Relative timestamps of the action chunk, derived from the dataset's
        # own frame rate. Only the action feature is listed here.
        delta_timestamps = {
            "action": [
                t / self.dataset_meta.fps for t in range(self.ACTION_CHUNK_SIZE)
            ],
        }
        self.dataset = LeRobotDataset(
            repo_id=repo_id,
            image_transforms=image_transforms,
            delta_timestamps=delta_timestamps,
        )
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.norm_stats_file, encoding="utf-8") as f:
            self.norm_stats = json.load(f)
        self.normalizer = Normalizer(
            norm_stats=self.norm_stats['norm_stats'],
            from_file=True,
            data_type='customized',
            norm_type={
                "observation.images.cam_high": "identity",
                "observation.images.cam_left_wrist": "identity",
                "observation.images.cam_right_wrist": "identity",
                "observation.state": data_config.norm_type,
                "action": data_config.norm_type,
            },
        )
        self.use_depth_align = use_depth_align

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def getdata(self, idx):
        """Load, normalize and pre-process item ``idx`` (no retry).

        Returns:
            A dict with keys ``'images'``, ``'img_masks'``, ``'state'``,
            ``'lang_tokens'``, ``'lang_masks'``, ``'actions'`` and
            ``'action_is_pad'``, plus ``'pil_images'`` when
            ``use_depth_align`` is set.

        Raises:
            AssertionError: If ``item['task']`` differs from the task the
                metadata lists for ``item['task_index']``.
        """
        item = self.dataset[idx]
        task = self.dataset_meta.tasks[int(item['task_index'])]
        assert task == item['task']

        normalized_item = self.normalizer.normalize(item)

        # Scale to 0..255 and truncate to uint8.
        # NOTE(review): presumably the images arrive as floats in [0, 1] -- confirm.
        base_image = (normalized_item["observation.images.cam_high"] * 255).to(torch.uint8)
        left_wrist_image = (
            normalized_item["observation.images.cam_left_wrist"] * 255
        ).to(torch.uint8)
        right_wrist_image = (
            normalized_item["observation.images.cam_right_wrist"] * 255
        ).to(torch.uint8)

        batch_dict = {
            "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
            "state": normalized_item["observation.state"].to(torch.float32),
            "action": normalized_item["action"].to(torch.float32),
            "action_is_pad": normalized_item["action_is_pad"],
            "prompt": [item["task"]],
        }

        # Shape notes carried over from the original author (unverified here):
        # state bs,8 -> bs,32; language bs,seq_len; actions bs,50,7 -> bs,50,32.
        state = prepare_state(self.config, batch_dict)
        lang_tokens, lang_masks = prepare_language(self.config, self.tokenizer, batch_dict)
        actions = prepare_action(self.config, batch_dict)
        images, img_masks, pil_images = prepare_images(
            self.config,
            self.image_processor,
            batch_dict,
            use_depth_align=self.use_depth_align,
        )

        sample = {
            'images': images,
            'img_masks': img_masks,
            'state': state,
            'lang_tokens': lang_tokens,
            'lang_masks': lang_masks,
            'actions': actions,
            'action_is_pad': batch_dict['action_is_pad'],
        }
        if self.use_depth_align:
            sample['pil_images'] = pil_images
        return sample

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return the sample for ``idx``, substituting a random one if loading fails.

        Raises:
            IndexError: If ``idx`` is negative or not smaller than ``len(self)``.
            RuntimeError: If every attempt failed; the last underlying error is
                attached as ``__cause__``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Index {idx} out of bounds.")

        attempts = 0
        cur = idx
        last_err = None
        while attempts < self.MAX_RETRIES:
            try:
                return self.getdata(cur)
            except Exception as e:
                last_err = e
                attempts += 1
                # randint's upper bound is exclusive, so the draw is always a
                # valid index and needs no clamping.
                cur = np.random.randint(0, len(self))

        raise RuntimeError(
            f"Failed to fetch a valid item starting from idx={idx} after {attempts} attempts. "
            f"Last error: {repr(last_err)}"
        ) from last_err