import argparse
import json
import os
import random
import signal
import sys
import time
from multiprocessing import Process

import numpy as np
import tensorflow as tf
import yaml

from data.vla_dataset import VLADataset
from data.filelock import FileLock
|
|

# The producer only prepares data on the CPU; hide all GPUs from TensorFlow
# so it does not reserve GPU memory needed by the training process.
tf.config.set_visible_devices([], 'GPU')

# Read the buffer configuration.
with open('configs/base.yaml', 'r') as file:
    config = yaml.safe_load(file)

BUF_PATH = config['dataset']['buf_path']
BUF_NUM_CHUNKS = config['dataset']['buf_num_chunks']
if BUF_NUM_CHUNKS < 1:
    raise ValueError("Config `buf_num_chunks` must be at least 1.")
BUF_CHUNK_SIZE = config['dataset']['buf_chunk_size']
if BUF_CHUNK_SIZE < 1:
    raise ValueError("Config `buf_chunk_size` must be at least 1.")


def get_dirty_item(chunk_dir):
    """
    Get the indexes of dirty items in a chunk.
    Dirty items are samples that have already been read by the consumer
    and may therefore be overwritten.
    """
    dirty_bit = read_dirty_bit(chunk_dir)
    return np.where(dirty_bit)[0].tolist()


def get_clean_item(chunk_dir):
    """
    Get the indexes of clean items in a chunk.
    Clean items are fresh samples that the consumer has not read yet.
    """
    dirty_bit = read_dirty_bit(chunk_dir)
    return np.where(1 - dirty_bit)[0].tolist()
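
# Worked example: with BUF_CHUNK_SIZE == 4 and a dirty bit of [1, 0, 0, 1],
# get_dirty_item returns [0, 3] and get_clean_item returns [1, 2].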


def save_dirty_bit(chunk_dir, dirty_bit):
    """
    Save the dirty bit to the chunk directory.
    Writing is retried under a write lock for up to 10 seconds before giving up.
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, 'wb') as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            lock.release_lock()
            raise
        except BaseException:
            lock.release_lock()
            continue
    print("Failed to save dirty bit.")


def read_dirty_bit(chunk_dir):
    """
    Read the dirty bit from the chunk directory.
    Reading is retried under a read lock for up to 10 seconds; on failure,
    falls back to all ones so the chunk is treated as fully dirty
    (i.e., safe to overwrite).
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, 'rb') as file:
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            assert len(dirty_bit) == BUF_CHUNK_SIZE
            return dirty_bit
        except KeyboardInterrupt:
            lock.release_lock()
            raise
        except BaseException:
            lock.release_lock()
            continue
    # Failed to read within the time limit; treat the whole chunk as dirty.
    return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
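
# On disk, "dirty_bit" holds exactly BUF_CHUNK_SIZE raw uint8 values, one per
# buffered sample. For example, with BUF_CHUNK_SIZE == 4, a file containing
# b"\x01\x00\x00\x01" marks items 0 and 3 as dirty and items 1 and 2 as clean.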


def save_sample(step_dict, chunk_dir, chunk_item_idx):
    """
    Save a sample to the chunk directory.
    Each sample is stored as a JSON file (metadata) plus an NPZ file
    (tensors), both written under write locks with retries for up to
    10 seconds.
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            locks = []
            # Write the JSON metadata.
            json_content = step_dict['json_content']
            file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, 'w') as file:
                json.dump(json_content, file, indent=4)
            lock.release_lock()

            # Write the tensors as an uncompressed NPZ archive.
            file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, 'wb') as file:
                np.savez(
                    file,
                    step_id=step_dict['step_id'].numpy(),
                    state_chunk=step_dict['state_chunk'].numpy(),
                    state_chunk_time_mask=step_dict['state_chunk_time_mask'].numpy(),
                    action_chunk=step_dict['action_chunk'].numpy(),
                    action_chunk_time_mask=step_dict['action_chunk_time_mask'].numpy(),
                    state_vec_mask=step_dict['state_vec_mask'].numpy(),
                    past_frames_0=step_dict['past_frames_0'].numpy(),
                    past_frames_0_time_mask=step_dict['past_frames_0_time_mask'].numpy(),
                    past_frames_1=step_dict['past_frames_1'].numpy(),
                    past_frames_1_time_mask=step_dict['past_frames_1_time_mask'].numpy(),
                    past_frames_2=step_dict['past_frames_2'].numpy(),
                    past_frames_2_time_mask=step_dict['past_frames_2_time_mask'].numpy(),
                    past_frames_3=step_dict['past_frames_3'].numpy(),
                    past_frames_3_time_mask=step_dict['past_frames_3_time_mask'].numpy(),
                    state_std=step_dict['state_std'].numpy(),
                    state_mean=step_dict['state_mean'].numpy(),
                    state_norm=step_dict['state_norm'].numpy(),
                )
            lock.release_lock()
            return
        except KeyboardInterrupt:
            for lock in locks:
                lock.release_lock()
            raise
        except BaseException:
            for lock in locks:
                lock.release_lock()
            continue
    print("Failed to save sample.")


def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type):
    """
    Run the producer.
    The producer first fills up the buffer with samples.
    It then keeps replacing dirty samples
    (i.e., samples that have been read by the consumer)
    with new ones.
    """
    vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type)
    # Each worker owns the contiguous chunk range [chunk_start_idx, chunk_end_idx).
    chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers
    chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers
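    # Worked example: with BUF_NUM_CHUNKS == 10 and num_workers == 4, the four
    # workers own the chunk ranges [0, 2), [2, 5), [5, 7), and [7, 10).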
    if fill_up:
        print(f"Worker {worker_id}: Start filling up the buffer...")
    elif clean_dirty:
        # No refill: just mark every owned chunk as fully clean.
        print(f"Worker {worker_id}: Start refreshing the dirty bits...")
        for chunk_idx in range(chunk_start_idx, chunk_end_idx):
            chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}")
            dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
            save_dirty_bit(chunk_dir, dirty_bit)
        print(f"Worker {worker_id}: Refreshed the dirty bits.")

    fill_chunk_idx = chunk_start_idx
    fill_chunk_item_idx = 0
    dirty_chunk_idx = chunk_start_idx
    dirty_chunk_item_idxs = []
    time_stmp = time.time()
    for episode_steps in vla_dataset:
        for step in episode_steps:
            if fill_up and fill_chunk_idx < chunk_end_idx:
                # Fill-up phase: write samples into the owned chunks in order.
                chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}")
                if fill_chunk_item_idx == 0:
                    # Starting a new chunk: create its directory and
                    # initialize its dirty bit to all clean.
                    os.makedirs(chunk_dir, exist_ok=True)
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(chunk_dir, dirty_bit)

                save_sample(step, chunk_dir, fill_chunk_item_idx)

                # Report progress at the start of every 10th chunk and of the last one.
                local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx
                local_num_chunks = chunk_end_idx - chunk_start_idx
                if (local_fill_chunk_idx % 10 == 0 or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0:
                    print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}")
                fill_chunk_item_idx += 1
                if fill_chunk_item_idx == BUF_CHUNK_SIZE:
                    fill_chunk_idx += 1
                    fill_chunk_item_idx = 0
                    if fill_chunk_idx == chunk_end_idx:
                        print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...")
            else:
                # Replacement phase: scan the owned chunks until one with
                # dirty (already consumed) items is found.
                while len(dirty_chunk_item_idxs) == 0:
                    dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
                    dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir)

                    # Periodically report how much of the current chunk is dirty.
                    if time.time() - time_stmp > 2.0:
                        dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE
                        print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}")
                        time_stmp = time.time()

                    if len(dirty_chunk_item_idxs) > 0:
                        # Lock the chunk by marking all items dirty so the
                        # consumer skips it while it is being rewritten.
                        dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
                        save_dirty_bit(dirty_chunk_dir, dirty_bit)

                    # Move to the next chunk, wrapping within the owned range.
                    dirty_chunk_idx += 1
                    if dirty_chunk_idx == chunk_end_idx:
                        dirty_chunk_idx = chunk_start_idx

                # Replace one dirty item. Note that `dirty_chunk_idx` has
                # already advanced past the selected chunk, so the write must
                # target `dirty_chunk_dir`.
                dirty_item_idx = dirty_chunk_item_idxs.pop()
                save_sample(step, dirty_chunk_dir, dirty_item_idx)

                if len(dirty_chunk_item_idxs) == 0:
                    # Every dirty item has been replaced: clear the dirty
                    # bits to release the chunk back to the consumer.
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(dirty_chunk_dir, dirty_bit)
                    print(f"Worker {worker_id}: Replaced dirty chunk at {dirty_chunk_dir}.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_workers', type=int, default=2,
                        help="Number of parallel workers. It should be no more than the number of chunks.")
    parser.add_argument('--fill_up', action='store_true',
                        help="Fill up the buffer before replacing dirty samples.")
    parser.add_argument('--clean_dirty', action='store_true',
                        help="Clean the dirty bits before replacing dirty samples. Ignored when `fill_up` is set.")
    parser.add_argument('--seed', type=int, default=None,
                        help="Random seed. If not set, the seed is generated randomly.")
    parser.add_argument('--dataset_type', type=str, default="pretrain",
                        help="Whether to load the pretrain or the finetune dataset.")

    args = parser.parse_args()
    if args.seed is not None:
        print(f"Base seed: {args.seed}")
        random.seed(args.seed)

    # Derive one seed per worker; keep seeds within the 32-bit range that
    # downstream RNGs (e.g., np.random.seed) accept.
    processes = []
    process_seeds = [random.randint(0, 2**32 - 1) for _ in range(args.n_workers)]
    print(f"Process seeds: {process_seeds}")

    # Terminate all child processes on Ctrl+C.
    def signal_handler(sig, frame):
        print("Ctrl+C received. Terminating child processes...")
        for p in processes:
            p.terminate()
        sys.exit(0)
    signal.signal(signal.SIGINT, signal_handler)

    for worker_id in range(args.n_workers):
        p = Process(target=run_producer, args=(
            process_seeds[worker_id], args.n_workers, worker_id,
            args.fill_up, args.clean_dirty, args.dataset_type))
        p.start()
        processes.append(p)

    # Wait for all workers to finish.
    for p in processes:
        p.join()