# File size: 8,161 Bytes (revision 33c751d)
from typing import Dict, List, Union
from queue import Empty
import numbers
import time
from multiprocessing.managers import SharedMemoryManager
import numpy as np
from diffusion_policy.shared_memory.shared_ndarray import SharedNDArray
from diffusion_policy.shared_memory.shared_memory_util import ArraySpec, SharedAtomicCounter
class SharedMemoryRingBuffer:
    """
    A Lock-Free FILO Shared Memory Data Structure.
    Stores a sequence of dict of numpy arrays.

    Every slot of the ring holds one time step: one entry per array spec.
    `put` writes the next slot and then advances a shared counter; `get`
    and `get_last_k` copy the most recently written slot(s) into local
    memory. No lock is taken. Instead, `put` refuses to (or waits before
    it does) overwrite slots that a reader may still be copying, using the
    per-slot write timestamps and `get_time_budget`.
    """

    def __init__(self,
            shm_manager: SharedMemoryManager,
            array_specs: "List[ArraySpec]",
            get_max_k: int,
            get_time_budget: float,
            put_desired_frequency: float,
            safety_margin: float=1.5
            ):
        """
        shm_manager: Manages the life cycle of shared memories
            across processes. Remember to run .start() before passing.
        array_specs: Name, shape and type of arrays for a single time step.
        get_max_k: The maximum number of items that can be queried at once.
        get_time_budget: The maximum amount of time spent copying data from
            shared memory to local memory. Increase this number for larger arrays.
        put_desired_frequency: The maximum frequency that .put() can be called.
            This influences the buffer size.
        safety_margin: Multiplier applied to the number of spare slots
            (put_desired_frequency * get_time_budget) when sizing the buffer.
        """
        # create atomic counter
        counter = SharedAtomicCounter(shm_manager)

        # compute buffer size
        # At any given moment, the past get_max_k items should never
        # be touched (to be read freely). Assuming the reader is reading
        # these k items, which takes a maximum of get_time_budget seconds,
        # we need enough empty slots to make sure put_desired_frequency Hz
        # of put can be sustained.
        buffer_size = int(np.ceil(
            put_desired_frequency * get_time_budget
            * safety_margin)) + get_max_k

        # allocate shared memory: one (buffer_size, *shape) array per spec
        shared_arrays = dict()
        for spec in array_specs:
            key = spec.name
            assert key not in shared_arrays
            array = SharedNDArray.create_from_shape(
                mem_mgr=shm_manager,
                shape=(buffer_size,) + tuple(spec.shape),
                dtype=spec.dtype)
            shared_arrays[key] = array

        # allocate timestamp array; -inf marks "never written", so the
        # age check in put() always passes for untouched slots
        timestamp_array = SharedNDArray.create_from_shape(
            mem_mgr=shm_manager,
            shape=(buffer_size,),
            dtype=np.float64)
        timestamp_array.get()[:] = -np.inf

        self.buffer_size = buffer_size
        self.array_specs = array_specs
        self.counter = counter
        self.shared_arrays = shared_arrays
        self.timestamp_array = timestamp_array
        self.get_time_budget = get_time_budget
        self.get_max_k = get_max_k
        self.put_desired_frequency = put_desired_frequency

    @property
    def count(self):
        """Total number of items put since creation or the last clear()."""
        return self.counter.load()

    @classmethod
    def create_from_examples(cls,
            shm_manager: SharedMemoryManager,
            examples: Dict[str, Union[np.ndarray, numbers.Number]],
            get_max_k: int=32,
            get_time_budget: float=0.01,
            put_desired_frequency: float=60
            ):
        """
        Build a ring buffer whose array specs are inferred from one example
        time step.

        examples: Maps each name to an example value. numpy arrays
            contribute their shape and dtype; Python/numpy scalars become
            0-d entries with dtype np.dtype(type(value)).
        The remaining arguments are forwarded to the constructor.

        Raises TypeError for values that are neither arrays nor numbers,
        and AssertionError when the inferred dtype is object.
        """
        specs = list()
        for key, value in examples.items():
            if isinstance(value, np.ndarray):
                shape = value.shape
                dtype = value.dtype
            elif isinstance(value, numbers.Number):
                shape = tuple()
                dtype = np.dtype(type(value))
            else:
                raise TypeError(f'Unsupported type {type(value)}')
            # Object arrays only hold references to Python objects, so they
            # cannot be stored in a raw shared buffer. The check applies to
            # scalars too (e.g. Fraction/Decimal map to object dtype).
            assert dtype != np.dtype('O'), \
                f'Object dtype is not supported (key {key!r})'
            spec = ArraySpec(
                name=key,
                shape=shape,
                dtype=dtype
            )
            specs.append(spec)

        obj = cls(
            shm_manager=shm_manager,
            array_specs=specs,
            get_max_k=get_max_k,
            get_time_budget=get_time_budget,
            put_desired_frequency=put_desired_frequency
            )
        return obj

    def clear(self):
        """Reset the item counter to zero. Slot contents and timestamps are left as they are."""
        self.counter.store(0)

    def put(self, data: Dict[str, Union[np.ndarray, numbers.Number]], wait: bool=True):
        """
        Write one time step into the next slot and advance the counter.

        data: Maps array names to values for this step. Only the keys
            present are written; a key without a matching shared array
            raises KeyError.
        wait: When the slot that must stay readable was written less than
            get_time_budget seconds ago, sleep for the remaining time if
            True, otherwise raise TimeoutError.
        """
        count = self.counter.load()
        next_idx = count % self.buffer_size
        # Make sure the next self.get_max_k elements in the ring buffer have at least
        # self.get_time_budget seconds untouched after written, so that
        # get_last_k can safely read k elements from any count location.
        # Sanity check: when get_max_k == 1, the element pointed by next_idx
        # should be rewritten at minimum self.get_time_budget seconds later.
        timestamp_lookahead_idx = (next_idx + self.get_max_k - 1) % self.buffer_size
        old_timestamp = self.timestamp_array.get()[timestamp_lookahead_idx]
        t = time.monotonic()
        deltat = t - old_timestamp
        if deltat < self.get_time_budget:
            if wait:
                # sleep the remaining time to be safe
                time.sleep(self.get_time_budget - deltat)
            else:
                # throw an error
                past_iters = self.buffer_size - self.get_max_k
                # deltat can be zero; avoid dividing by it for the
                # (purely informational) rate in the message
                hz = past_iters / deltat if deltat > 0 else float('inf')
                raise TimeoutError(
                    'Put executed too fast {}items/{:.4f}s ~= {}Hz'.format(
                        past_iters, deltat,hz))

        # write to shared memory
        for key, value in data.items():
            arr: np.ndarray
            arr = self.shared_arrays[key].get()
            if isinstance(value, np.ndarray):
                arr[next_idx] = value
            else:
                arr[next_idx] = np.array(value, dtype=arr.dtype)

        # update timestamp, then publish the slot by bumping the counter
        self.timestamp_array.get()[next_idx] = time.monotonic()
        self.counter.add(1)

    def _allocate_empty(self, k=None):
        """
        Allocate uninitialized local output arrays, one per array spec.

        k: When given, each array gets a leading axis of length k (for
            get_last_k); when None, arrays have exactly the spec's shape.
        """
        result = dict()
        for spec in self.array_specs:
            # normalise to a tuple, as __init__ does, so list-like shapes
            # can be concatenated with the leading axis
            shape = tuple(spec.shape)
            if k is not None:
                shape = (k,) + shape
            result[spec.name] = np.empty(
                shape=shape, dtype=spec.dtype)
        return result

    def get(self, out=None) -> Dict[str, np.ndarray]:
        """
        Copy the most recently put time step into local memory.

        out: Optional dict of preallocated arrays to copy into; allocated
            here when None.
        Returns the dict of arrays. Raises TimeoutError when the copy took
        longer than get_time_budget. When nothing has been put yet
        (count == 0), the index wraps to the last slot and whatever that
        slot currently holds is returned.
        """
        if out is None:
            out = self._allocate_empty()
        start_time = time.monotonic()
        count = self.counter.load()
        curr_idx = (count - 1) % self.buffer_size
        for key, value in self.shared_arrays.items():
            arr = value.get()
            np.copyto(out[key], arr[curr_idx])
        end_time = time.monotonic()
        dt = end_time - start_time
        if dt > self.get_time_budget:
            raise TimeoutError(f'Get time out {dt} vs {self.get_time_budget}')
        return out

    def get_last_k(self, k:int, out=None) -> Dict[str, np.ndarray]:
        """
        Copy the k most recently put time steps into local memory.

        k: Number of steps to fetch; must not exceed get_max_k nor the
            current count.
        out: Optional dict of preallocated arrays with a leading axis of
            length k; allocated here when None.
        Returns the dict of arrays, ordered oldest first (index k-1 is the
        newest step). Raises TimeoutError when the copy took longer than
        get_time_budget.
        """
        assert k <= self.get_max_k, \
            f'k={k} exceeds get_max_k={self.get_max_k}'
        if out is None:
            out = self._allocate_empty(k)
        start_time = time.monotonic()
        count = self.counter.load()
        assert k <= count, f'k={k} exceeds current count={count}'
        curr_idx = (count - 1) % self.buffer_size

        # Contiguous run of slots that ends at the newest item.
        end = curr_idx + 1
        start = max(0, end - k)
        # Older items that wrapped around sit at the tail of the buffer.
        remainder = k - (end - start)
        wrap_start = self.buffer_size - remainder

        for key, value in self.shared_arrays.items():
            arr = value.get()
            target = out[key]
            # newest items fill the back of the output
            target[remainder:k] = arr[start:end]
            if remainder > 0:
                # wrap around: older items fill the front of the output
                target[0:remainder] = arr[wrap_start:self.buffer_size]
        end_time = time.monotonic()
        dt = end_time - start_time
        if dt > self.get_time_budget:
            raise TimeoutError(f'Get time out {dt} vs {self.get_time_budget}')
        return out

    def get_all(self) -> Dict[str, np.ndarray]:
        """Return as many of the latest steps as are available, capped at get_max_k."""
        k = min(self.count, self.get_max_k)
        return self.get_last_k(k=k)
|