# Hosting-page header captured together with this file (not part of the module):
#   lniki's picture
#   add model
#   0e83290 verified
import torch
import warnings
from XMem2.inference.kv_memory_store import KeyValueMemoryStore
from XMem2.model.memory_util import *
class MemoryManager:
    """
    Manages all three memory stores and the transition between working/long-term memory.

    The stores are the temporary working memory, the permanent working memory
    and, only when ``enable_long_term`` is set, the long-term memory. Permanent
    frames live in their own store and are concatenated with the other stores
    only when memory is read (see ``match_memory``).
    """

    def __init__(self, config):
        """
        Read the hyper-parameters from ``config`` and create empty stores.

        ``config`` is indexed with 'hidden_dim', 'top_k', 'enable_long_term'
        and 'enable_long_term_count_usage'. When long-term memory is enabled
        it is also indexed with 'max_mid_term_frames', 'min_mid_term_frames',
        'num_prototypes' and 'max_long_term_elements'.
        """
        self.config = config
        self.hidden_dim = config['hidden_dim']
        self.top_k = config['top_k']

        self.enable_long_term = config['enable_long_term']
        self.enable_long_term_usage = config['enable_long_term_count_usage']
        if self.enable_long_term:
            # maximum work memory size
            self.max_mt_frames = config['max_mid_term_frames']
            # minimum number of frames to keep in work memory when consolidating
            self.min_mt_frames = config['min_mid_term_frames']
            self.num_prototypes = config['num_prototypes']
            self.max_long_elements = config['max_long_term_elements']

        # dimensions will be inferred from input later (first add_memory call)
        self.CK = self.CV = None
        self.H = self.W = None
        # H * W; defined up-front so that methods reading it before the first
        # add_memory call see None instead of raising AttributeError
        self.HW = None

        # The hidden state will be stored in a single tensor for all objects
        # B x num_objects x CH x H x W
        self.hidden = None

        self.temporary_work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
        self.permanent_work_mem = KeyValueMemoryStore(count_usage=False)
        # frame index -> position returned by permanent_work_mem.add
        self.frame_id_to_permanent_mem_idx = {}
        if self.enable_long_term:
            self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)

        # forces add_memory to (re)compute H, W and the derived element limits
        self.reset_config = True

    def update_config(self, config):
        """
        Apply new hyper-parameters while keeping the stored memory.

        'enable_long_term' and 'enable_long_term_count_usage' must be equal to
        the values currently in use; this is checked with ``assert`` before
        anything is modified. ``reset_config`` is raised so that the next
        ``add_memory`` call recomputes the spatial size and element limits.
        """
        # validate first so that a rejected config leaves the manager untouched
        assert self.enable_long_term == config['enable_long_term'], 'cannot update this'
        assert (
            self.enable_long_term_usage == config['enable_long_term_count_usage']
        ), 'cannot update this'

        # remember the new config: copy_perm_mem_only builds its copy from
        # self.config and would otherwise fall back to the old settings
        self.config = config
        self.reset_config = True
        self.hidden_dim = config['hidden_dim']
        self.top_k = config['top_k']

        self.enable_long_term_usage = config['enable_long_term_count_usage']
        if self.enable_long_term:
            self.max_mt_frames = config['max_mid_term_frames']
            self.min_mt_frames = config['min_mid_term_frames']
            self.num_prototypes = config['num_prototypes']
            self.max_long_elements = config['max_long_term_elements']

    def _readout(self, affinity, v):
        """Return ``v @ affinity``; this function is for a single object group."""
        return v @ affinity

    @staticmethod
    def _last_columns(similarity, count):
        """
        Return the last ``count`` entries of ``similarity`` along dim 1.

        An explicit start index is used because ``x[:, -count:]`` selects
        every column, not zero columns, when ``count`` is 0.
        """
        return similarity[:, similarity.shape[1] - count:]

    def match_memory(self, query_key, selection, disable_usage_updates=False):
        """
        Read out memory values for a query.

        Keys and shrinkage terms of the stores in use are concatenated along
        the last dimension (long-term first when it is engaged, then
        temporary, then permanent), compared with the query through
        ``get_similarity`` and turned into one affinity per object group with
        ``do_softmax``. Each group's concatenated values are then multiplied
        with that group's affinity.

        Args:
            query_key: B x C^k x H x W
            selection: B x C^k x H x W, or None
            disable_usage_updates: when True, no usage statistics are
                recorded on the memory stores.

        Returns:
            The per-group readouts concatenated along dim 0 and viewed with
            trailing dimensions ``CV x H x W``.
        """
        # TODO: keep groups in both..?
        # = permanent_work_mem.num_groups, since it's always >= temporary_work_mem.num_groups
        temp_mem = self.temporary_work_mem
        perm_mem = self.permanent_work_mem
        num_groups = max(temp_mem.num_groups, perm_mem.num_groups)
        h, w = query_key.shape[-2:]

        query_key = query_key.flatten(start_dim=2)
        if selection is not None:
            selection = selection.flatten(start_dim=2)

        last = self._last_columns
        temp_work_mem_size = temp_mem.size

        if self.enable_long_term and self.long_mem.engaged():
            # Use long-term memory
            long_mem = self.long_mem
            long_mem_size = long_mem.size
            temp_end = long_mem_size + temp_work_mem_size

            memory_key = torch.cat([long_mem.key, temp_mem.key, perm_mem.key], -1)
            shrinkage = torch.cat(
                [long_mem.shrinkage, temp_mem.shrinkage, perm_mem.shrinkage], -1
            )
            similarity = get_similarity(memory_key, shrinkage, query_key, selection)

            long_sim = similarity[:, :long_mem_size]
            temp_sim = similarity[:, long_mem_size:temp_end]
            perm_sim = similarity[:, temp_end:]

            # get the usage with the first group
            # the first group always have all the keys valid
            first_affinity, usage = do_softmax(
                torch.cat(
                    [last(long_sim, long_mem.get_v_size(0)), temp_sim, perm_sim], 1
                ),
                top_k=self.top_k,
                inplace=True,
                return_usage=True,
            )
            affinity = [first_affinity]

            # compute affinity group by group as later groups only have a subset of keys
            for gi in range(1, num_groups):
                group_sims = [
                    last(temp_sim, temp_mem.get_v_size(gi)),
                    last(perm_sim, perm_mem.get_v_size(gi)),
                ]
                has_long_term = gi < long_mem.num_groups
                if has_long_term:
                    # merge working and lt similarities before softmax
                    group_sims.insert(0, last(long_sim, long_mem.get_v_size(gi)))
                affinity.append(
                    do_softmax(
                        torch.cat(group_sims, dim=1),
                        top_k=self.top_k,
                        inplace=has_long_term or gi == num_groups - 1,
                    )
                )

            all_memory_value = []
            for gi in range(num_groups):
                group_values = [temp_mem.value[gi], perm_mem.value[gi]]
                if gi < long_mem.num_groups:
                    # merge the working and lt values before readout
                    group_values.insert(0, long_mem.value[gi])
                all_memory_value.append(torch.cat(group_values, -1))

            # Record memory usage for working and long-term memory
            if not disable_usage_updates:
                # ignore the index return for long-term memory;
                # no usage for permanent memory
                work_usage = usage[:, long_mem_size:temp_end]
                temp_mem.update_usage(work_usage.flatten())

                # NOTE(review): nested under disable_usage_updates so the flag
                # silences every usage update - confirm against upstream.
                if self.enable_long_term_usage:
                    # ignore the index return for working memory
                    long_usage = usage[:, :long_mem_size]
                    long_mem.update_usage(long_usage.flatten())
        else:
            # No long-term memory
            memory_key = torch.cat([temp_mem.key, perm_mem.key], -1)
            shrinkage = torch.cat([temp_mem.shrinkage, perm_mem.shrinkage], -1)
            similarity = get_similarity(memory_key, shrinkage, query_key, selection)

            temp_sim = similarity[:, :temp_work_mem_size]
            perm_sim = similarity[:, temp_work_mem_size:]

            # the slices above are views of `similarity`, so the softmax may
            # only work in place when no later group needs them
            if self.enable_long_term:
                first_affinity, usage = do_softmax(
                    similarity,
                    inplace=(num_groups == 1),
                    top_k=self.top_k,
                    return_usage=True,
                )
                if not disable_usage_updates:
                    # Record memory usage for working memory
                    temp_mem.update_usage(usage[:, :temp_work_mem_size].flatten())
            else:
                first_affinity = do_softmax(
                    similarity,
                    inplace=(num_groups == 1),
                    top_k=self.top_k,
                    return_usage=False,
                )
            affinity = [first_affinity]

            # compute affinity group by group as later groups only have a subset of keys
            for gi in range(1, num_groups):
                group_sims = [
                    # empty tensor if the group is also empty for temporary memory
                    last(temp_sim, temp_mem.get_v_size(gi)),
                    last(perm_sim, perm_mem.get_v_size(gi)),
                ]
                affinity.append(
                    do_softmax(
                        torch.cat(group_sims, dim=1),
                        top_k=self.top_k,
                        inplace=(gi == num_groups - 1),
                    )
                )

            all_memory_value = [
                torch.cat([temp_mem.value[gi], perm_mem.value[gi]], -1)
                for gi in range(num_groups)
            ]

        # Shared affinity within each group
        all_readout_mem = torch.cat(
            [self._readout(affinity[gi], gv) for gi, gv in enumerate(all_memory_value)],
            0,
        )
        return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)

    def update_permanent_memory(self, frame_idx, key, shrinkage, value, selection=None):
        """
        Overwrite the permanent-memory entry recorded for ``frame_idx``.

        The inputs are flattened exactly as in ``add_memory`` and written with
        ``permanent_work_mem.replace_at`` at the recorded position.

        Raises:
            KeyError: if ``frame_idx`` has no recorded permanent position.
        """
        saved_pos = self.frame_id_to_permanent_mem_idx[frame_idx]

        key = key.flatten(start_dim=2)
        shrinkage = shrinkage.flatten(start_dim=2)
        value = value[0].flatten(start_dim=2)
        if selection is not None:
            selection = selection.flatten(start_dim=2)

        self.permanent_work_mem.replace_at(saved_pos, key, value, shrinkage, selection)

    def remove_from_permanent_memory(self, frame_idx):
        """
        Remove the ``HW`` elements recorded for ``frame_idx`` from permanent memory.

        Raises:
            KeyError: if ``frame_idx`` has no recorded permanent position.
        """
        # NOTE(review): positions recorded for other frames are not adjusted
        # here; verify whether remove_at can leave them stale.
        saved_pos = self.frame_id_to_permanent_mem_idx[frame_idx]
        self.permanent_work_mem.remove_at(saved_pos, self.HW)
        del self.frame_id_to_permanent_mem_idx[frame_idx]

    def add_memory(
        self,
        key,
        shrinkage,
        value,
        objects,
        selection=None,
        permanent=False,
        ignore=False,
        ti=None,
    ):
        """
        Store one frame in the temporary or the permanent working memory.

        Args:
            key: 1*C*H*W
            shrinkage: flattened from dim 2 like ``key``
            value: 1*num_objects*C*H*W (only ``value[0]`` is stored)
            objects: list of object indices
            selection: optional; flattened from dim 2 like ``key``. A
                ``UserWarning`` is emitted when it is given while long-term
                memory is disabled.
            permanent: store the frame in the permanent working memory.
            ignore: store nothing; used for frames that are already in the
                permanent memory.
            ti: frame index; when given for a permanent frame, the position
                returned by the store is recorded under it.

        After storing, the store with fewer groups (or a not-yet-engaged one)
        receives an empty entry so both can be concatenated, and, in
        long-term mode, the temporary memory is consolidated once it reaches
        ``max_work_elements``.
        """
        if self.H is None or self.reset_config:
            self.reset_config = False
            self.H, self.W = key.shape[-2:]
            self.HW = self.H * self.W
            if self.enable_long_term:
                # convert from num. frames to num. nodes
                self.min_work_elements = self.min_mt_frames * self.HW
                self.max_work_elements = self.max_mt_frames * self.HW

        # key: 1*C*N
        # value: num_objects*C*N
        key = key.flatten(start_dim=2)
        shrinkage = shrinkage.flatten(start_dim=2)
        value = value[0].flatten(start_dim=2)

        self.CK = key.shape[1]
        self.CV = value.shape[1]

        if selection is not None:
            if not self.enable_long_term:
                warnings.warn(
                    'the selection factor is only needed in long-term mode', UserWarning
                )
            selection = selection.flatten(start_dim=2)

        # `ignore`: all permanent frames are pre-placed into permanent memory
        # (when using our memory modification); it also skips the first frame
        # (#0) when using the original memory mechanism, since it's already in
        # the permanent memory
        if not ignore:
            if permanent:
                pos = self.permanent_work_mem.add(
                    key, value, shrinkage, selection, objects
                )
                if ti is not None:
                    self.frame_id_to_permanent_mem_idx[ti] = pos
            else:
                self.temporary_work_mem.add(key, value, shrinkage, selection, objects)

        num_temp_groups = self.temporary_work_mem.num_groups
        num_perm_groups = self.permanent_work_mem.num_groups

        if not self.temporary_work_mem.engaged() or (
            num_temp_groups != num_perm_groups
        ):
            # first frame or new group; we need to have both memories engaged
            # to avoid crashes when concating, so the lagging one is
            # initialized with an empty tensor
            key0 = key[..., 0:0]
            value0 = value[..., 0:0]
            shrinkage0 = shrinkage[..., 0:0]
            # selection is optional, so it must not be sliced unconditionally
            selection0 = selection[..., 0:0] if selection is not None else None

            if num_perm_groups > num_temp_groups:
                # for preloading into permanent memory
                lagging_mem = self.temporary_work_mem
            else:
                # for original memory mechanism
                lagging_mem = self.permanent_work_mem
            lagging_mem.add(key0, value0, shrinkage0, selection0, objects)

        # long-term memory cleanup
        if self.enable_long_term:
            # Do memory compressed if needed
            if self.temporary_work_mem.size >= self.max_work_elements:
                # if we have more then N elements in the work memory
                # Remove obsolete features if needed
                long_term_budget = self.max_long_elements - self.num_prototypes
                if self.long_mem.size >= long_term_budget:
                    self.long_mem.remove_obsolete_features(long_term_budget)

                # We NEVER remove anything from the permanent working memory
                self.compress_features()

    def create_hidden_state(self, n, sample_key):
        """
        Create the hidden state, or extend it with zeros, to hold ``n`` objects.

        Args:
            n: the TOTAL number of objects.
            sample_key: tensor whose last two dimensions give H and W and
                whose device is used for the newly allocated zeros.
        """
        h, w = sample_key.shape[-2:]
        if self.hidden is None:
            self.hidden = torch.zeros(
                (1, n, self.hidden_dim, h, w), device=sample_key.device
            )
        elif self.hidden.shape[1] != n:
            num_new_objects = n - self.hidden.shape[1]
            self.hidden = torch.cat(
                [
                    self.hidden,
                    torch.zeros(
                        (1, num_new_objects, self.hidden_dim, h, w),
                        device=sample_key.device,
                    ),
                ],
                1,
            )

        assert self.hidden.shape[1] == n

    def set_hidden(self, hidden):
        """Replace the stored hidden state with ``hidden``."""
        self.hidden = hidden

    def get_hidden(self):
        """Return the stored hidden state (``None`` until it is created or set)."""
        return self.hidden

    def frame_already_saved(self, ti):
        """Return True if frame index ``ti`` has a recorded permanent-memory position."""
        return ti in self.frame_id_to_permanent_mem_idx

    def compress_features(self):
        """
        Consolidate the temporary working memory into long-term prototypes.

        Everything except the last ``min_work_elements`` entries is passed to
        ``consolidation``, then dropped from the temporary memory with
        ``sieve_by_range``, and the resulting prototypes are added to the
        long-term memory.
        """
        HW = self.HW
        keep = self.min_work_elements
        total_work_mem_size = self.temporary_work_mem.size

        candidate_value = []
        for gv in self.temporary_work_mem.value:
            # Some object groups might be added later in the video
            # So not all keys have values associated with all objects
            # We need to keep track of the key->value validity
            mem_size_in_this_group = gv.shape[-1]
            if mem_size_in_this_group != total_work_mem_size:
                # mem_size is smaller than total_work_mem_size, but at least HW
                assert HW <= mem_size_in_this_group < total_work_mem_size

            if mem_size_in_this_group == total_work_mem_size or (
                mem_size_in_this_group > keep
            ):
                # all or part of this object group goes into LT
                candidate_value.append(gv[:, :, :-keep])
            else:
                # this object group cannot go to the LT at all
                candidate_value.append(None)

        # perform memory consolidation
        # now starts at zero, because the 1st frame is going into permanent memory
        prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
            *self.temporary_work_mem.get_all_sliced(0, -keep), candidate_value
        )

        # remove consolidated working memory
        self.temporary_work_mem.sieve_by_range(0, -keep, min_size=keep + HW)

        # add to long-term memory
        self.long_mem.add(
            prototype_key,
            prototype_value,
            prototype_shrinkage,
            selection=None,
            objects=None,
        )

    def consolidation(
        self,
        candidate_key,
        candidate_shrinkage,
        candidate_selection,
        usage,
        candidate_value,
    ):
        """
        Build ``num_prototypes`` long-term prototypes from candidate memory.

        The candidates with the highest ``usage`` are picked as prototype
        keys; prototype values (per group) and the prototype shrinkage are
        read out from the candidates through a softmax over the
        candidate/prototype similarity.

        Args:
            candidate_key: 1*C*N
            candidate_shrinkage: shrinkage of the candidates, or None
            candidate_selection: selection of the candidates, or None
            usage: usage scores; the top ``num_prototypes`` along the last
                dimension are selected
            candidate_value: list with one ``num_objects*C*N`` tensor per
                group, or None for groups that contribute nothing

        Returns:
            ``(prototype_key, prototype_value, prototype_shrinkage)`` where
            ``prototype_value`` is a list with one entry per group (None for
            groups without a valid prototype).
        """
        N = candidate_key.shape[-1]

        # find the indices with max usage
        _, max_usage_indices = torch.topk(
            usage, k=self.num_prototypes, dim=-1, sorted=True
        )
        prototype_indices = max_usage_indices.flatten()

        # Prototypes are invalid for out-of-bound groups
        validity = [
            prototype_indices >= (N - gv.shape[2]) if gv is not None else None
            for gv in candidate_value
        ]

        prototype_key = candidate_key[:, :, prototype_indices]
        if candidate_selection is not None:
            prototype_selection = candidate_selection[:, :, prototype_indices]
        else:
            prototype_selection = None

        # Potentiation step
        similarity = get_similarity(
            candidate_key, candidate_shrinkage, prototype_key, prototype_selection
        )

        # convert similarity to affinity
        # need to do it group by group since the softmax normalization would be different
        affinity = []
        for gi, gv in enumerate(candidate_value):
            if gv is None:
                affinity.append(None)
                continue
            group_affinity = do_softmax(similarity[:, -gv.shape[2]:, validity[gi]])
            # some groups can have all-False validity. Weed them out.
            affinity.append(group_affinity if group_affinity.shape[-1] > 0 else None)

        # readout the values
        prototype_value = [
            self._readout(aff, gv) if aff is not None else None
            for aff, gv in zip(affinity, candidate_value)
        ]

        # readout the shrinkage term
        if candidate_shrinkage is not None:
            prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage)
        else:
            prototype_shrinkage = None

        return prototype_key, prototype_value, prototype_shrinkage

    def copy_perm_mem_only(self):
        """
        Return a new manager that carries over only the permanent memory.

        The new manager is built from ``self.config``. If the permanent memory
        is empty it is returned as is. Otherwise it shares (not copies) this
        manager's permanent store and frame-index map, gets an empty entry in
        its temporary store so both stores can be concatenated, a freshly
        created hidden state, and this manager's object bookkeeping and
        dimensions.
        """
        new_mem = MemoryManager(config=self.config)
        perm_mem = self.permanent_work_mem

        if perm_mem.key is None or perm_mem.key.size(-1) == 0:
            return new_mem

        new_mem.permanent_work_mem = perm_mem
        new_mem.frame_id_to_permanent_mem_idx = self.frame_id_to_permanent_mem_idx

        # engage the temporary store with zero-length tensors
        key0 = perm_mem.key[..., 0:0]
        value0 = perm_mem.value[0][..., 0:0]
        shrinkage0 = (
            perm_mem.shrinkage[..., 0:0] if perm_mem.shrinkage is not None else None
        )
        selection0 = (
            perm_mem.selection[..., 0:0] if perm_mem.selection is not None else None
        )
        new_mem.temporary_work_mem.add(
            key0, value0, shrinkage0, selection0, perm_mem.all_objects
        )

        # un-flatten the first stored frame to obtain H, W and the device
        key_shape = perm_mem.key.shape
        sample_key = perm_mem.key[..., 0:self.HW].view(
            *key_shape[:-1], self.H, self.W
        )
        new_mem.create_hidden_state(len(perm_mem.all_objects), sample_key)

        new_mem.temporary_work_mem.obj_groups = self.temporary_work_mem.obj_groups
        new_mem.temporary_work_mem.all_objects = self.temporary_work_mem.all_objects

        new_mem.CK = self.CK
        new_mem.CV = self.CV
        new_mem.H = self.H
        new_mem.W = self.W
        new_mem.HW = self.HW

        return new_mem