| import json |
| import numpy as np |
| import os |
| import re |
| import time |
| from pathlib import Path |
| from dataclasses import asdict, dataclass, field |
| from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional |
|
|
| import torch |
| from tqdm import trange, tqdm |
| from torch.utils.data import DataLoader |
| from lingbotvla.models import build_processor |
| from lingbotvla.utils import helper |
| from lingbotvla.utils.arguments import DataArguments, ModelArguments, TrainingArguments, parse_args |
| import lingbotvla.utils.normalize as normalize |
| from lingbotvla.data.vla_data.base_dataset import VlaDataset |
|
|
|
|
| if TYPE_CHECKING: |
| from transformers import ProcessorMixin |
|
|
| from lingbotvla.data.chat_template import ChatTemplate |
|
|
| logger = helper.create_logger(__name__) |
|
|
|
|
@dataclass
class MyDataArguments(DataArguments):
    """``DataArguments`` extended with the two options this script adds."""

    # Location the computed statistics are written to (see ``main``).
    norm_path: str = field(
        default=None,
        metadata=dict(help="Path to save norm stats."),
    )
    # Forwarded as ``chunk_size`` to ``get_statistics`` for keys flagged in
    # ``delta_norm`` (see ``main``).
    chunk_size: int = field(
        default=50,
        metadata=dict(help="Chunk size of action."),
    )
|
|
|
|
@dataclass
class Arguments:
    """Root argument container passed to ``parse_args``: one field per group."""

    # Annotations are deliberately kept as string forward references.
    model: "ModelArguments" = field(
        default_factory=ModelArguments,
    )
    data: "MyDataArguments" = field(
        default_factory=MyDataArguments,
    )
    train: "TrainingArguments" = field(
        default_factory=TrainingArguments,
    )
|
|
def compute_norm(
    dataset,
    task_id,
    batch_size,
    stats,
    state_norm_keys,
    acton_norm_keys,
    delta_norm,
    num_workers=16,
):
    """Accumulate running statistics over every batch of ``dataset``.

    Batches are drawn in order (``shuffle=False``) and the trailing incomplete
    batch is skipped (``drop_last=True``), so samples in that last partial
    batch do not contribute to the statistics.

    Args:
        dataset: Dataset accepted by ``torch.utils.data.DataLoader``; each
            batch is indexed with the keys listed below.
        task_id: Label shown in the progress bar and in the failure message.
        batch_size: Batch size given to the ``DataLoader``.
        stats: Mapping from key to an accumulator exposing
            ``update(array_2d)``. Updated in place.
        state_norm_keys: Batch keys whose values are reshaped to
            ``(-1, last_dim)`` and accumulated unchanged.
        acton_norm_keys: Batch keys for action values (the parameter name,
            typo included, is kept for keyword callers).
        delta_norm: Mapping from action key to a flag. When the flag is
            truthy the value is flattened to ``(batch, -1)`` before
            accumulation; otherwise only index 0 along axis 1 is used. A key
            absent from the mapping is treated as False.
        num_workers: Number of ``DataLoader`` worker processes. Defaults to
            16, the value that used to be hard-coded.

    Returns:
        bool: True when every batch was processed, False when an exception
        was raised while iterating. Accumulators may then hold partial data.
    """
    import traceback

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=True,
    )
    try:
        for batch in tqdm(data_loader, desc=f"Computing stats of {task_id}"):
            for key in state_norm_keys:
                values = np.asarray(batch[key])
                stats[key].update(values.reshape(-1, values.shape[-1]))
            for key in acton_norm_keys:
                actions = batch[key]
                if delta_norm.get(key, False):
                    values = np.asarray(actions.reshape(actions.shape[0], -1))
                else:
                    values = np.asarray(actions[:, 0])
                stats[key].update(values.reshape(-1, values.shape[-1]))
    except Exception as exc:
        # Keep the best-effort contract (report failure through the return
        # value) but say what went wrong, and let KeyboardInterrupt and
        # SystemExit propagate instead of being swallowed.
        print(f"Failed to compute stats of {task_id}: {exc!r}")
        traceback.print_exc()
        return False
    return True
|
|
|
|
def main():
    """Compute normalization statistics for a VLA dataset and save them.

    Parses the command-line arguments, builds a ``VlaDataset`` from
    ``args.data.train_path``, accumulates running statistics for the
    ``'action'`` and ``'observation.state'`` keys, and writes the result to
    ``args.data.norm_path`` through ``normalize.save``. Nothing is written
    when the accumulation fails.

    Raises:
        ValueError: If ``args.data.datasets_type`` is not ``'vla'`` or if
            ``args.data.norm_path`` is unset.
    """
    args = parse_args(Arguments)
    logger.info(f"Process rank: {args.train.global_rank}, world size: {args.train.world_size}")
    logger.info_rank0(json.dumps(asdict(args), indent=2))

    # Validate with explicit raises: ``assert`` is stripped under ``python -O``.
    if args.data.datasets_type != 'vla':
        raise ValueError(
            f"Unsupported datasets_type {args.data.datasets_type!r}; expected 'vla'."
        )
    # Fail before the pass over the data rather than at ``Path(None)`` after it.
    if args.data.norm_path is None:
        raise ValueError(
            "norm_path is required: set it to the path where the norm stats should be written."
        )

    logger.info_rank0("Prepare data")
    dataset = VlaDataset(repo_id=args.data.train_path, action_name='action')

    state_norm_keys = ['observation.state']
    acton_norm_keys = ['action']
    delta_norm = {'action': False}
    stats = {key: normalize.RunningStats() for key in acton_norm_keys + state_norm_keys}

    chunk_size = args.data.chunk_size

    try:
        success = compute_norm(
            dataset,
            args.data.train_path,
            args.train.global_batch_size,
            stats,
            state_norm_keys,
            acton_norm_keys,
            delta_norm,
        )
    except Exception as e:
        # Covers errors raised outside compute_norm's own handler, such as
        # DataLoader construction. ``success`` must still be bound afterwards.
        success = False
        fail_info = f"{args.data.train_path} {e}"
        print(fail_info)

    if not success:
        print(f"Norm stats for {args.data.train_path} were not computed; nothing was written.")
        return

    norm_stats = {}
    sample_count = 0
    for key, running in stats.items():
        if delta_norm.get(key, False):
            norm_stats[key] = running.get_statistics(chunk_size=chunk_size)
        else:
            norm_stats[key] = running.get_statistics()
        # The count passed to ``normalize.save`` is that of the last
        # accumulator in iteration order, as before.
        sample_count = running._count

    output_path = Path(args.data.norm_path)
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats, sample_count)
| |
|
|
|
|
# Entry point: run the computation only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main()