| import sys |
| import os |
|
|
# Make the parent of this file's directory (presumably the repository root)
# importable and the working directory, so the `diffusion_policy` imports
# that follow resolve when this file is run directly.
# abspath() guards against a relative __file__ (e.g. Python < 3.9 running
# `python tests/this_file.py`), where dirname(dirname(...)) yields '' and
# os.chdir('') would raise FileNotFoundError.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ROOT_DIR)
os.chdir(ROOT_DIR)
|
|
| import numpy as np |
| import time |
| from diffusion_policy.common.timestamp_accumulator import ( |
| get_accumulate_timestamp_idxs, |
| TimestampObsAccumulator, |
| TimestampActionAccumulator |
| ) |
|
|
|
|
def test_index():
    """Check get_accumulate_timestamp_idxs across chained and fresh calls.

    Three scenarios are exercised:

    1. Two time-overlapping batches accumulated with a carried-over
       ``next_global_idx`` must emit each global slot exactly once, in
       order, and the timestamps written into those slots must be strictly
       increasing.
    2. A sparse batch of 3 timestamps spanning 0..1 s must start at local
       index 0 and end at local index 2.
    3. Timestamps placed exactly one ``dt`` apart from a wall-clock
       ``start_time`` must map local index i to global index i.
    """
    start_time = 0.0
    dt = 1 / 10

    # --- 1. chained accumulation over two overlapping batches -------------
    # 16 slots: presumably 11 windows for 0..1 s plus 5 more up to 1.5 s,
    # assuming one slot per dt window.
    buffer = np.zeros(16)
    gi = []
    next_global_idx = 0

    timestamps = np.linspace(0, 1, 100)
    local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(
        timestamps, start_time=start_time, dt=dt, next_global_idx=next_global_idx)
    assert local_idxs[0] == 0
    assert global_idxs[0] == 0
    buffer[global_idxs] = timestamps[local_idxs]
    gi.extend(global_idxs)

    # Second batch overlaps the first in time; carrying next_global_idx over
    # means only slots not yet emitted should be produced.
    timestamps = np.linspace(0.5, 1.5, 100)
    local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(
        timestamps, start_time=start_time, dt=dt, next_global_idx=next_global_idx)
    buffer[global_idxs] = timestamps[local_idxs]
    gi.extend(global_idxs)

    assert np.all(buffer[1:] > buffer[:-1])
    # Every slot of the buffer was emitted exactly once, in order.
    assert len(gi) == len(buffer)
    assert np.array_equal(gi, np.arange(len(gi)))

    # --- 2. sparse timestamps over a fresh accumulation --------------------
    timestamps = np.linspace(0, 1, 3)
    local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
        timestamps, start_time=start_time, dt=dt, next_global_idx=0)
    assert local_idxs[0] == 0
    assert local_idxs[-1] == 2

    # --- 3. epoch-scale start_time, one timestamp per window ---------------
    # A realistic wall-clock start_time exercises floating-point rounding
    # of timestamps near ~1e9 seconds.
    start_time = time.time()
    timestamps = np.arange(100000) * dt + start_time
    local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
        timestamps, start_time=start_time, dt=dt, next_global_idx=0)
    # np.array_equal works whether the helper returns lists or ndarrays; a
    # bare `==` on ndarrays would raise "truth value ... is ambiguous".
    assert np.array_equal(local_idxs, global_idxs)
| |
| |
| |
|
|
|
|
def test_obs_accumulator():
    """Check TimestampObsAccumulator down-sampling and gap filling.

    Part 1: 100 poses sampled every 10 ms and accumulated into 100 ms
    windows must keep every 10th pose (10 entries). Putting the same batch
    again must not change anything.

    Part 2: 10 poses sampled every 200 ms and accumulated into 100 ms
    windows must yield ``1 + (n-1) * 2`` entries. A second batch shifted
    2 samples (400 ms) later must add exactly 4 more.
    """
    def make_poses(n, d):
        # (n, d) integer array whose row i is filled with the value i.
        return np.repeat(np.arange(n).reshape((n, 1)), d, axis=1)

    dt = 1 / 10
    d = 6

    # --- Part 1: source faster than the accumulator -------------------------
    ddt = 1 / 100
    n = 100
    start_time = time.time()
    toa = TimestampObsAccumulator(start_time, dt)
    poses = make_poses(n, d)
    timestamps = np.arange(n) * ddt + start_time

    toa.put({
        'pose': poses,
        'timestamp': timestamps
    }, timestamps)
    np.testing.assert_array_equal(toa.data['pose'][:, 0], np.arange(10) * 10)
    assert len(toa) == 10

    # Re-putting identical data must leave the accumulator unchanged.
    toa.put({
        'pose': poses,
        'timestamp': timestamps
    }, timestamps)
    np.testing.assert_array_equal(toa.data['pose'][:, 0], np.arange(10) * 10)
    assert len(toa) == 10

    # --- Part 2: source slower than the accumulator -------------------------
    ddt = 1 / 5
    n = 10
    start_time = time.time()
    toa = TimestampObsAccumulator(start_time, dt)
    poses = make_poses(n, d)
    timestamps = np.arange(n) * ddt + start_time

    toa.put({
        'pose': poses,
        'timestamp': timestamps
    }, timestamps)
    assert len(toa) == 1 + (n - 1) * 2

    # Shift the batch 2 samples later: 2 * ddt / dt = 4 new windows.
    timestamps = (np.arange(n) + 2) * ddt + start_time
    toa.put({
        'pose': poses,
        'timestamp': timestamps
    }, timestamps)
    assert len(toa) == 1 + (n - 1) * 2 + 4
|
|
|
|
def test_action_accumulator():
    """Check TimestampActionAccumulator insertion, overwrite and extension.

    1. Actions put exactly one ``dt`` apart are stored verbatim, along with
       their timestamps.
    2. Putting ``actions - 5`` at ``timestamps - 0.5`` must leave the stored
       timestamps matching the originals (row k+5 shifted by -0.5 s lands on
       row k's time, and ``actions[k+5] - 5 == actions[k]``).
    3. Putting ``actions + 5`` at ``timestamps + 0.5`` must extend the
       accumulator to 15 entries whose values form the ramp 0..14.
    """
    dt = 1 / 10
    n = 10
    d = 6
    start_time = time.time()
    taa = TimestampActionAccumulator(start_time, dt)
    # (n, d) integer array whose row i is filled with the value i.
    actions = np.repeat(np.arange(n).reshape((n, 1)), d, axis=1)
    timestamps = np.arange(n) * dt + start_time

    taa.put(actions, timestamps)
    np.testing.assert_array_equal(taa.actions, actions)
    np.testing.assert_array_equal(taa.timestamps, timestamps)

    taa.put(actions - 5, timestamps - 0.5)
    # Compare offsets from start_time with an absolute tolerance. On raw
    # epoch-scale values (~1.7e9 s), np.allclose's default rtol=1e-5 would
    # accept errors of thousands of seconds, making this check vacuous.
    assert np.allclose(
        taa.timestamps - start_time, timestamps - start_time, rtol=0, atol=1e-5)

    taa.put(actions + 5, timestamps + 0.5)
    assert len(taa) == 15
    np.testing.assert_array_equal(taa.actions[:, 0], np.arange(15))
| |
|
|
|
|
if __name__ == '__main__':
    # Run every test in this module so that direct execution covers the
    # same cases pytest would collect.
    test_index()
    test_obs_accumulator()
    test_action_accumulator()
|
|