|
|
import torch
|
|
|
import warnings
|
|
|
|
|
|
from XMem2.inference.kv_memory_store import KeyValueMemoryStore
|
|
|
from XMem2.model.memory_util import *
|
|
|
|
|
|
|
|
|
class MemoryManager:
|
|
|
"""
|
|
|
Manages all three memory stores and the transition between working/long-term memory
|
|
|
"""
|
|
|
|
|
|
def __init__(self, config):
|
|
|
self.config = config
|
|
|
self.hidden_dim = config['hidden_dim']
|
|
|
self.top_k = config['top_k']
|
|
|
|
|
|
self.enable_long_term = config['enable_long_term']
|
|
|
self.enable_long_term_usage = config['enable_long_term_count_usage']
|
|
|
if self.enable_long_term:
|
|
|
self.max_mt_frames = config[
|
|
|
'max_mid_term_frames'
|
|
|
]
|
|
|
self.min_mt_frames = config[
|
|
|
'min_mid_term_frames'
|
|
|
]
|
|
|
self.num_prototypes = config['num_prototypes']
|
|
|
self.max_long_elements = config['max_long_term_elements']
|
|
|
|
|
|
|
|
|
self.CK = self.CV = None
|
|
|
self.H = self.W = None
|
|
|
|
|
|
|
|
|
|
|
|
self.hidden = None
|
|
|
|
|
|
self.temporary_work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
|
|
|
self.permanent_work_mem = KeyValueMemoryStore(count_usage=False)
|
|
|
self.frame_id_to_permanent_mem_idx = dict()
|
|
|
if self.enable_long_term:
|
|
|
self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)
|
|
|
|
|
|
self.reset_config = True
|
|
|
|
|
|
def update_config(self, config):
|
|
|
self.reset_config = True
|
|
|
self.hidden_dim = config['hidden_dim']
|
|
|
self.top_k = config['top_k']
|
|
|
|
|
|
assert self.enable_long_term == config['enable_long_term'], 'cannot update this'
|
|
|
assert (
|
|
|
self.enable_long_term_usage == config['enable_long_term_count_usage']
|
|
|
), 'cannot update this'
|
|
|
|
|
|
self.enable_long_term_usage = config['enable_long_term_count_usage']
|
|
|
if self.enable_long_term:
|
|
|
self.max_mt_frames = config['max_mid_term_frames']
|
|
|
self.min_mt_frames = config['min_mid_term_frames']
|
|
|
self.num_prototypes = config['num_prototypes']
|
|
|
self.max_long_elements = config['max_long_term_elements']
|
|
|
|
|
|
def _readout(self, affinity, v):
|
|
|
|
|
|
return v @ affinity
|
|
|
|
|
|
    def match_memory(self, query_key, selection, disable_usage_updates=False):
        """Read out memory for a query frame across all engaged memory stores.

        query_key: query key features, flattened spatially below
            (assumed (1, C^k, H, W) — TODO confirm against caller).
        selection: optional selection factor used by get_similarity
            (long-term mode only).
        disable_usage_updates: when True, skip recording usage statistics.

        Returns a tensor reshaped to (N, self.CV, h, w), where N stacks the
        per-group readouts concatenated along dim 0.
        """
        # Object groups may differ between the two working memories; iterate
        # over the larger count so every group is read out.
        num_groups = max(
            self.temporary_work_mem.num_groups, self.permanent_work_mem.num_groups
        )
        h, w = query_key.shape[-2:]

        # Collapse spatial dims: keys/selection become (..., h*w).
        query_key = query_key.flatten(start_dim=2)
        selection = selection.flatten(start_dim=2) if selection is not None else None

        """
        Memory readout using keys
        """

        temp_work_mem_size = self.temporary_work_mem.size
        if self.enable_long_term and self.long_mem.engaged():
            # Long-term memory holds entries: compare the query against the
            # concatenation [long | temporary | permanent] in one pass.
            long_mem_size = self.long_mem.size

            memory_key = torch.cat(
                [
                    self.long_mem.key,
                    self.temporary_work_mem.key,
                    self.permanent_work_mem.key,
                ],
                -1,
            )
            shrinkage = torch.cat(
                [
                    self.long_mem.shrinkage,
                    self.temporary_work_mem.shrinkage,
                    self.permanent_work_mem.shrinkage,
                ],
                -1,
            )

            similarity = get_similarity(memory_key, shrinkage, query_key, selection)

            # Split the joint similarity back into the three store segments,
            # in the same order they were concatenated above.
            long_mem_similarity = similarity[:, :long_mem_size]
            temp_work_mem_similarity = similarity[
                :, long_mem_size : long_mem_size + temp_work_mem_size
            ]
            perm_work_mem_similarity = similarity[
                :, long_mem_size + temp_work_mem_size :
            ]

            # Group 0: softmax over that group's tail of the long-term
            # segment plus the full temporary and permanent segments.
            # Usage is only collected here (group 0 spans all elements).
            affinity, usage = do_softmax(
                torch.cat(
                    [
                        long_mem_similarity[:, -self.long_mem.get_v_size(0) :],
                        temp_work_mem_similarity,
                        perm_work_mem_similarity,
                    ],
                    1,
                ),
                top_k=self.top_k,
                inplace=True,
                return_usage=True,
            )
            affinity = [affinity]

            # Later groups only cover the newest elements of each segment
            # (right-aligned slices of size get_v_size(gi)).
            for gi in range(1, num_groups):
                temp_group_v_size = self.temporary_work_mem.get_v_size(gi)
                perm_group_v_size = self.permanent_work_mem.get_v_size(gi)
                temp_sim_size = temp_work_mem_similarity.shape[1]
                perm_sim_size = perm_work_mem_similarity.shape[1]

                if gi < self.long_mem.num_groups:
                    # This group also exists in long-term memory.
                    affinity_one_group = do_softmax(
                        torch.cat(
                            [
                                long_mem_similarity[:, -self.long_mem.get_v_size(gi) :],
                                temp_work_mem_similarity[
                                    :, temp_sim_size - temp_group_v_size :
                                ],
                                perm_work_mem_similarity[
                                    :, perm_sim_size - perm_group_v_size :
                                ],
                            ],
                            dim=1,
                        ),
                        top_k=self.top_k,
                        inplace=True,
                    )
                else:
                    # Working-memory-only group.
                    # inplace only on the last group: earlier iterations must
                    # not clobber similarity slices still needed later.
                    affinity_one_group = do_softmax(
                        torch.cat(
                            [
                                temp_work_mem_similarity[
                                    :, temp_sim_size - temp_group_v_size :
                                ],
                                perm_work_mem_similarity[
                                    :, perm_sim_size - perm_group_v_size :
                                ],
                            ],
                            1,
                        ),
                        top_k=self.top_k,
                        inplace=(gi == num_groups - 1),
                    )
                affinity.append(affinity_one_group)

            # Gather per-group value tensors in the same store order used for
            # the affinities above, so columns line up for the readout.
            all_memory_value = []
            for gi in range(num_groups):
                if gi < self.long_mem.num_groups:
                    all_memory_value.append(
                        torch.cat(
                            [
                                self.long_mem.value[gi],
                                self.temporary_work_mem.value[gi],
                                self.permanent_work_mem.value[gi],
                            ],
                            -1,
                        )
                    )
                else:
                    all_memory_value.append(
                        torch.cat(
                            [
                                self.temporary_work_mem.value[gi],
                                self.permanent_work_mem.value[gi],
                            ],
                            -1,
                        )
                    )

            """
            Record memory usage for working and long-term memory
            """
            if not disable_usage_updates:
                # Usage slices mirror the [long | temporary | ...] layout of
                # the group-0 softmax input. Permanent memory never counts usage.
                work_usage = usage[
                    :, long_mem_size : long_mem_size + temp_work_mem_size
                ]
                self.temporary_work_mem.update_usage(work_usage.flatten())

                if self.enable_long_term_usage:
                    long_usage = usage[:, :long_mem_size]
                    self.long_mem.update_usage(long_usage.flatten())
        else:
            # No long-term entries engaged: compare only against the two
            # working memories, [temporary | permanent].
            memory_key = torch.cat(
                [self.temporary_work_mem.key, self.permanent_work_mem.key], -1
            )
            shrinkage = torch.cat(
                [self.temporary_work_mem.shrinkage, self.permanent_work_mem.shrinkage],
                -1,
            )

            similarity = get_similarity(memory_key, shrinkage, query_key, selection)
            temp_work_mem_similarity = similarity[:, :temp_work_mem_size]
            perm_work_mem_similarity = similarity[:, temp_work_mem_size:]

            if self.enable_long_term:
                # Long-term is enabled but not yet engaged: still collect
                # usage so consolidation can pick prototypes later.
                affinity, usage = do_softmax(
                    similarity,
                    inplace=(num_groups == 1),
                    top_k=self.top_k,
                    return_usage=True,
                )
                if not disable_usage_updates:
                    # Only the temporary segment tracks usage.
                    self.temporary_work_mem.update_usage(
                        usage[:, :temp_work_mem_size].flatten()
                    )
            else:
                affinity = do_softmax(
                    similarity,
                    inplace=(num_groups == 1),
                    top_k=self.top_k,
                    return_usage=False,
                )

            affinity = [affinity]

            # Per-group affinities over each group's right-aligned slices,
            # as in the long-term branch above.
            for gi in range(1, num_groups):
                temp_group_v_size = self.temporary_work_mem.get_v_size(gi)
                perm_group_v_size = self.permanent_work_mem.get_v_size(gi)
                temp_sim_size = temp_work_mem_similarity.shape[1]
                perm_sim_size = perm_work_mem_similarity.shape[1]

                affinity_one_group = do_softmax(
                    torch.cat(
                        [
                            temp_work_mem_similarity[
                                :, temp_sim_size - temp_group_v_size :
                            ],
                            perm_work_mem_similarity[
                                :, perm_sim_size - perm_group_v_size :
                            ],
                        ],
                        dim=1,
                    ),
                    top_k=self.top_k,
                    inplace=(gi == num_groups - 1),
                )
                affinity.append(affinity_one_group)

            all_memory_value = []
            for gi in range(num_groups):
                group_v_cat = torch.cat(
                    [
                        self.temporary_work_mem.value[gi],
                        self.permanent_work_mem.value[gi],
                    ],
                    -1,
                )
                all_memory_value.append(group_v_cat)

        # Read out each group's values with its affinity and stack along dim 0.
        all_readout_mem = torch.cat(
            [self._readout(affinity[gi], gv) for gi, gv in enumerate(all_memory_value)],
            0,
        )

        return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)
|
|
|
|
|
|
def update_permanent_memory(self, frame_idx, key, shrinkage, value, selection=None):
|
|
|
saved_pos = self.frame_id_to_permanent_mem_idx[frame_idx]
|
|
|
|
|
|
key = key.flatten(start_dim=2)
|
|
|
shrinkage = shrinkage.flatten(start_dim=2)
|
|
|
value = value[0].flatten(start_dim=2)
|
|
|
|
|
|
if selection is not None:
|
|
|
selection = selection.flatten(start_dim=2)
|
|
|
|
|
|
self.permanent_work_mem.replace_at(saved_pos, key, value, shrinkage, selection)
|
|
|
|
|
|
def remove_from_permanent_memory(self, frame_idx):
|
|
|
elem_size = self.HW
|
|
|
saved_pos = self.frame_id_to_permanent_mem_idx[frame_idx]
|
|
|
|
|
|
self.permanent_work_mem.remove_at(saved_pos, elem_size)
|
|
|
|
|
|
del self.frame_id_to_permanent_mem_idx[frame_idx]
|
|
|
|
|
|
def add_memory(
|
|
|
self,
|
|
|
key,
|
|
|
shrinkage,
|
|
|
value,
|
|
|
objects,
|
|
|
selection=None,
|
|
|
permanent=False,
|
|
|
ignore=False,
|
|
|
ti=None,
|
|
|
):
|
|
|
|
|
|
|
|
|
|
|
|
if self.H is None or self.reset_config:
|
|
|
self.reset_config = False
|
|
|
self.H, self.W = key.shape[-2:]
|
|
|
self.HW = self.H * self.W
|
|
|
if self.enable_long_term:
|
|
|
|
|
|
self.min_work_elements = self.min_mt_frames * self.HW
|
|
|
self.max_work_elements = self.max_mt_frames * self.HW
|
|
|
|
|
|
|
|
|
|
|
|
key = key.flatten(start_dim=2)
|
|
|
shrinkage = shrinkage.flatten(start_dim=2)
|
|
|
value = value[0].flatten(start_dim=2)
|
|
|
|
|
|
self.CK = key.shape[1]
|
|
|
self.CV = value.shape[1]
|
|
|
|
|
|
if selection is not None:
|
|
|
if not self.enable_long_term:
|
|
|
warnings.warn(
|
|
|
'the selection factor is only needed in long-term mode', UserWarning
|
|
|
)
|
|
|
selection = selection.flatten(start_dim=2)
|
|
|
|
|
|
if ignore:
|
|
|
pass
|
|
|
|
|
|
elif permanent:
|
|
|
pos = self.permanent_work_mem.add(key, value, shrinkage, selection, objects)
|
|
|
if ti is not None:
|
|
|
self.frame_id_to_permanent_mem_idx[ti] = pos
|
|
|
else:
|
|
|
self.temporary_work_mem.add(key, value, shrinkage, selection, objects)
|
|
|
|
|
|
num_temp_groups = self.temporary_work_mem.num_groups
|
|
|
num_perm_groups = self.permanent_work_mem.num_groups
|
|
|
|
|
|
if not self.temporary_work_mem.engaged() or (
|
|
|
num_temp_groups != num_perm_groups
|
|
|
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
key0 = key[..., 0:0]
|
|
|
value0 = value[..., 0:0]
|
|
|
shrinkage0 = shrinkage[..., 0:0]
|
|
|
selection0 = selection[..., 0:0]
|
|
|
if num_perm_groups > num_temp_groups:
|
|
|
|
|
|
self.temporary_work_mem.add(
|
|
|
key0, value0, shrinkage0, selection0, objects
|
|
|
)
|
|
|
else:
|
|
|
|
|
|
self.permanent_work_mem.add(
|
|
|
key0, value0, shrinkage0, selection0, objects
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.enable_long_term:
|
|
|
|
|
|
if self.temporary_work_mem.size >= self.max_work_elements:
|
|
|
|
|
|
|
|
|
if self.long_mem.size >= (self.max_long_elements - self.num_prototypes):
|
|
|
self.long_mem.remove_obsolete_features(
|
|
|
self.max_long_elements - self.num_prototypes
|
|
|
)
|
|
|
|
|
|
|
|
|
self.compress_features()
|
|
|
|
|
|
def create_hidden_state(self, n, sample_key):
|
|
|
|
|
|
h, w = sample_key.shape[-2:]
|
|
|
if self.hidden is None:
|
|
|
self.hidden = torch.zeros(
|
|
|
(1, n, self.hidden_dim, h, w), device=sample_key.device
|
|
|
)
|
|
|
elif self.hidden.shape[1] != n:
|
|
|
self.hidden = torch.cat(
|
|
|
[
|
|
|
self.hidden,
|
|
|
torch.zeros(
|
|
|
(1, n - self.hidden.shape[1], self.hidden_dim, h, w),
|
|
|
device=sample_key.device,
|
|
|
),
|
|
|
],
|
|
|
1,
|
|
|
)
|
|
|
|
|
|
assert self.hidden.shape[1] == n
|
|
|
|
|
|
def set_hidden(self, hidden):
|
|
|
self.hidden = hidden
|
|
|
|
|
|
def get_hidden(self):
|
|
|
return self.hidden
|
|
|
|
|
|
def frame_already_saved(self, ti):
|
|
|
return ti in self.frame_id_to_permanent_mem_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def compress_features(self):
        """Consolidate the oldest part of temporary work memory into long-term prototypes."""
        HW = self.HW
        candidate_value = []
        total_work_mem_size = self.temporary_work_mem.size
        # Per object group, collect the value slice that is old enough to be
        # consolidated: everything except the newest min_work_elements.
        for gv in self.temporary_work_mem.value:
            mem_size_in_this_group = gv.shape[-1]
            if mem_size_in_this_group == total_work_mem_size:
                # This group spans the whole temporary memory.
                candidate_value.append(gv[:, :, : -self.min_work_elements])
            else:
                # Newer group: it only covers a (right-aligned) tail of the
                # temporary memory, at least one frame's worth.
                assert HW <= mem_size_in_this_group < total_work_mem_size
                if mem_size_in_this_group > self.min_work_elements:
                    candidate_value.append(gv[:, :, : -self.min_work_elements])
                else:
                    # Too new: nothing old enough to consolidate yet.
                    candidate_value.append(None)

        # Turn the candidate range [0, -min_work_elements) into prototypes.
        prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
            *self.temporary_work_mem.get_all_sliced(0, -self.min_work_elements),
            candidate_value
        )

        # Remove the consolidated range from temporary memory.
        # NOTE(review): exact min_size semantics are defined by
        # KeyValueMemoryStore.sieve_by_range — confirm against that class.
        self.temporary_work_mem.sieve_by_range(
            0, -self.min_work_elements, min_size=self.min_work_elements + HW
        )

        # Store the new prototypes in long-term memory (no selection/objects
        # metadata is kept for prototypes).
        self.long_mem.add(
            prototype_key,
            prototype_value,
            prototype_shrinkage,
            selection=None,
            objects=None,
        )
|
|
|
|
|
|
    def consolidation(
        self,
        candidate_key,
        candidate_shrinkage,
        candidate_selection,
        usage,
        candidate_value,
    ):
        """Compress candidate memory elements into num_prototypes prototypes.

        candidate_key / candidate_shrinkage / candidate_selection: flattened
            features of the candidate range (selection/shrinkage may be None).
        usage: per-element usage counts used to rank candidates.
        candidate_value: per-group value slices; None for groups too new to
            contribute candidates.

        Returns (prototype_key, prototype_value list, prototype_shrinkage).
        """
        # N: number of candidate memory elements.
        N = candidate_key.shape[-1]

        # The most-used elements become the prototype anchors.
        _, max_usage_indices = torch.topk(
            usage, k=self.num_prototypes, dim=-1, sorted=True
        )
        prototype_indices = max_usage_indices.flatten()

        # A prototype index is valid for group gi only if it falls inside that
        # group's right-aligned candidate span (newer groups cover fewer
        # trailing elements).
        validity = [
            prototype_indices >= (N - gv.shape[2]) if gv is not None else None
            for gv in candidate_value
        ]

        prototype_key = candidate_key[:, :, prototype_indices]
        prototype_selection = (
            candidate_selection[:, :, prototype_indices]
            if candidate_selection is not None
            else None
        )

        """
        Potentiation step
        """
        # Similarity of every candidate to each prototype (prototypes act as
        # the queries here).
        similarity = get_similarity(
            candidate_key, candidate_shrinkage, prototype_key, prototype_selection
        )

        # Per group: softmax over that group's candidate rows, keeping only
        # the prototype columns valid for the group.
        affinity = [
            (
                do_softmax(similarity[:, -gv.shape[2] :, validity[gi]])
                if gv is not None
                else None
            )
            for gi, gv in enumerate(candidate_value)
        ]

        # Groups for which no prototype was valid yield zero-width affinities;
        # treat them the same as missing groups.
        affinity = [
            aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
        ]

        # Prototype values are the affinity-weighted readout of each group's
        # candidate values.
        prototype_value = [
            self._readout(affinity[gi], gv) if affinity[gi] is not None else None
            for gi, gv in enumerate(candidate_value)
        ]

        # Shrinkage prototypes reuse group 0's affinity.
        # NOTE(review): assumes group 0 spans all candidates — confirm.
        prototype_shrinkage = (
            self._readout(affinity[0], candidate_shrinkage)
            if candidate_shrinkage is not None
            else None
        )

        return prototype_key, prototype_value, prototype_shrinkage
|
|
|
|
|
|
    def copy_perm_mem_only(self):
        """Create a new MemoryManager carrying over only the permanent memory.

        NOTE(review): the permanent store and the frame-index map are shared
        by reference with the returned manager, not deep-copied — mutations
        through one manager are visible through the other.
        """
        new_mem = MemoryManager(config=self.config)

        # Nothing saved permanently yet: return a fresh, empty manager.
        if (
            self.permanent_work_mem.key is None
            or self.permanent_work_mem.key.size(-1) == 0
        ):
            return new_mem

        new_mem.permanent_work_mem = self.permanent_work_mem
        new_mem.frame_id_to_permanent_mem_idx = self.frame_id_to_permanent_mem_idx

        # Zero-width slices preserve dtype/channel layout but hold no data;
        # they seed the temporary store so both stores have matching groups.
        key0 = self.permanent_work_mem.key[..., 0:0]
        value0 = self.permanent_work_mem.value[0][..., 0:0]
        shrinkage0 = (
            self.permanent_work_mem.shrinkage[..., 0:0]
            if self.permanent_work_mem.shrinkage is not None
            else None
        )
        selection0 = (
            self.permanent_work_mem.selection[..., 0:0]
            if self.permanent_work_mem.selection is not None
            else None
        )

        new_mem.temporary_work_mem.add(
            key0, value0, shrinkage0, selection0, self.permanent_work_mem.all_objects
        )

        new_mem.CK = self.permanent_work_mem.key.shape[1]
        new_mem.CV = self.permanent_work_mem.value[0].shape[1]

        # Rebuild a sample key at the original spatial resolution so the new
        # manager's hidden state is sized for every known object.
        key_shape = self.permanent_work_mem.key.shape
        sample_key = self.permanent_work_mem.key[..., 0 : self.HW].view(
            *key_shape[:-1], self.H, self.W
        )
        new_mem.create_hidden_state(
            len(self.permanent_work_mem.all_objects), sample_key
        )

        new_mem.temporary_work_mem.obj_groups = self.temporary_work_mem.obj_groups
        new_mem.temporary_work_mem.all_objects = self.temporary_work_mem.all_objects

        # Copy cached dimensions (overwrites the CK/CV derived above).
        new_mem.CK = self.CK
        new_mem.CV = self.CV
        new_mem.H = self.H
        new_mem.W = self.W
        new_mem.HW = self.HW

        return new_mem
|
|
|
|