# ATCTrack-VLM/lib/train/data/sampler.py
# Commit b3f019f (verified), author SunXiang2025:
# Update: two-stage training, per-channel FiLM gate, cosine scheduler, 9B config
import random
import cv2
import torch.utils.data
from lib.utils import TensorDict
import numpy as np
# from pytorch_pretrained_bert import BertTokenizer
import os
def no_processing(data):
    """Identity transform: hand the sample back exactly as it was received.

    Serves as the default ``processing`` callable of ``TrackingSampler``.
    """
    return data
class TrackingSampler(torch.utils.data.Dataset):
    """ Class responsible for sampling frames from training sequences to form batches.

    The sampling is done in the following ways. First a dataset is selected at random (weighted by
    'p_datasets'). Next, a sequence is selected from that dataset. Template and search frame ids are then
    drawn from the sequence according to 'frame_sample_mode':
        'causal'                - a base template frame is sampled, extra template frames are taken from
                                  [base - max_gap, base) and search frames from (base, base + max_gap).
                                  If not enough visible frames are found, the gap is widened step by step.
        'order'                 - all frames are sampled around a base frame and sorted chronologically
                                  (ascending or descending, chosen at random).
        'trident'/'trident_pro' - one initial template, one search frame and one dynamic template per
                                  entry of 'max_gap'.
        'stark'                 - like 'trident', but dynamic templates are drawn from the 'valid' mask.
    The sampled frames are then passed through the input 'processing' function for the necessary processing.
    """

    def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
                 num_search_frames, num_template_frames=1, processing=no_processing, frame_sample_mode='causal',
                 train_cls=False, pos_prob=0.5):
        """
        args:
            datasets - List of datasets to be used for training
            p_datasets - List containing the probabilities (or unnormalised weights) by which each dataset
                         will be sampled. If None, each dataset is weighted by its length.
            samples_per_epoch - Number of training samples per epoch
            max_gap - Maximum gap, in frame numbers, between the template frames and the search frames.
                      Used as a number in 'causal' / 'order' mode and iterated over (one dynamic template
                      per entry) in 'trident' / 'trident_pro' / 'stark' mode.
            num_search_frames - Number of search frames to sample.
            num_template_frames - Number of template frames to sample.
            processing - Callable applied to every sampled TensorDict. The samplers read data['valid'] from
                         its result, so it is expected to provide that entry.
            frame_sample_mode - 'causal', 'order', 'trident', 'trident_pro' or 'stark'.
            train_cls - this is for Stark-ST, should be False for SeqTrack.
            pos_prob - Probability of drawing a positive pair when train_cls is True.
        """
        self.datasets = datasets
        self.train_cls = train_cls  # whether we are training classification
        self.pos_prob = pos_prob  # probability of sampling positive class when making classification

        # If p not provided, sample uniformly from all videos
        if p_datasets is None:
            p_datasets = [len(d) for d in self.datasets]

        # Normalize
        p_total = sum(p_datasets)
        self.p_datasets = [x / p_total for x in p_datasets]

        self.samples_per_epoch = samples_per_epoch
        self.max_gap = max_gap
        self.num_search_frames = num_search_frames
        self.num_template_frames = num_template_frames
        self.processing = processing
        self.frame_sample_mode = frame_sample_mode
        self.multi_modal_language = False

    def __len__(self):
        """ Epoch length; samples are drawn at random, so this is simply 'samples_per_epoch'. """
        return self.samples_per_epoch

    def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None,
                            allow_invisible=False, force_invisible=False):
        """ Samples num_ids frames (with replacement) between min_id and max_id for which target is visible

        args:
            visible - 1d Tensor indicating whether target is visible for each frame
            num_ids - number of frames to be sampled
            min_id - Minimum allowed frame number (inclusive); None or negative values are treated as 0
            max_id - Maximum allowed frame number (exclusive); None or too large values become len(visible)
            allow_invisible - if True, every frame in the range is a candidate regardless of visibility
            force_invisible - if True, only frames where the target is NOT visible are candidates
                              (takes precedence over allow_invisible)

        returns:
            list - List of sampled frame numbers. Empty list if num_ids is 0. None if no candidate frame
                   exists in the range.
        """
        if num_ids == 0:
            return []
        if min_id is None or min_id < 0:
            min_id = 0
        if max_id is None or max_id > len(visible):
            max_id = len(visible)

        # get valid ids
        if force_invisible:
            valid_ids = [i for i in range(min_id, max_id) if not visible[i]]
        elif allow_invisible:
            valid_ids = list(range(min_id, max_id))
        else:
            valid_ids = [i for i in range(min_id, max_id) if visible[i]]

        # No candidate ids
        if not valid_ids:
            return None

        return random.choices(valid_ids, k=num_ids)

    def __getitem__(self, index):
        """ Returns a random training sample; 'index' is ignored since we sample randomly. """
        if self.train_cls:
            return self.getitem_cls()
        return self.getitem()

    def _get_frame_ids_causal(self, visible):
        """ Samples template and search ids in a causal manner, i.e. search_frame_ids > template_frame_ids[0]

        The gap around the base frame is widened by 5 after every attempt until enough frames are found.

        returns:
            (template_frame_ids, search_frame_ids) - two lists of frame numbers
        """
        template_frame_ids = None
        search_frame_ids = None
        gap_increase = 0

        while search_frame_ids is None:
            base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_template_frames - 1,
                                                     max_id=len(visible) - self.num_search_frames)
            prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_template_frames - 1,
                                                      min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                      max_id=base_frame_id[0])
            if prev_frame_ids is None:
                gap_increase += 5
                if gap_increase > 1000:
                    print("too large image gap, check the sampler, current gap: "+str(gap_increase))
                continue

            template_frame_ids = base_frame_id + prev_frame_ids
            search_frame_ids = self._sample_visible_ids(visible, min_id=template_frame_ids[0] + 1,
                                                        max_id=template_frame_ids[0] + self.max_gap + gap_increase,
                                                        num_ids=self.num_search_frames)
            # Increase gap until a frame is found
            gap_increase += 5
            if gap_increase > 1000:
                print("too large image gap, check the sampler, current gap: " + str(gap_increase))

        return template_frame_ids, search_frame_ids

    def getitem(self):
        """ Draws random (template, search) samples until 'processing' marks one as valid.

        Any exception raised while picking a sequence, loading frames or processing is logged and the
        sample is re-drawn. The raw language description template_anno["nlp"][0] is stored under
        data["nlp"]; no tokenization is done here.

        returns:
            TensorDict - dict containing all the data blocks
        """
        valid = False
        count_valid = 0
        while not valid:
            # Select a dataset
            dataset = random.choices(self.datasets, self.p_datasets)[0]
            is_video_dataset = dataset.is_video_sequence()

            # sample a sequence from the given dataset
            try:
                seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
            except Exception as e:
                print(f"data sampler bug: {e}")
                valid = False
                continue

            if is_video_dataset:
                mode = self.frame_sample_mode
                if mode == 'causal':
                    template_frame_ids, search_frame_ids = self._get_frame_ids_causal(visible)
                elif mode == "order":
                    template_frame_ids, search_frame_ids = self.get_frame_ids_order(visible)
                elif mode == "trident" or mode == "trident_pro":
                    template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
                elif mode == "stark":
                    template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
                else:
                    raise ValueError("Illegal frame sample mode")
            else:
                # In case of image dataset, just repeat the image to generate synthetic video
                template_frame_ids = [1] * self.num_template_frames
                search_frame_ids = [1] * self.num_search_frames

            try:
                template_frames, template_anno, _ = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
                search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)

                H, W, _ = template_frames[0].shape
                # fall back to all-zero masks when the dataset provides none
                template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros((H, W))] * self.num_template_frames
                search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros((H, W))] * self.num_search_frames

                data = TensorDict({'template_images': template_frames,
                                   'template_anno': template_anno['bbox'],
                                   'template_masks': template_masks,
                                   'template_frame_ids': torch.tensor(template_frame_ids, dtype=torch.long),
                                   'search_images': search_frames,
                                   'search_anno': search_anno['bbox'],
                                   'search_masks': search_masks,
                                   'search_frame_ids': torch.tensor(search_frame_ids, dtype=torch.long),
                                   'dataset': dataset.get_name(),
                                   'seq_name': meta_obj_test.get('seq_name', ''),
                                   'test_class': meta_obj_test.get('object_class_name')})

                # forward the raw sentence; a missing "nlp" annotation raises and the sample is re-drawn
                data["nlp"] = template_anno["nlp"][0]

                # make data augmentation
                data = self.processing(data)

                # check whether data is valid
                valid = data['valid']
            except Exception as e:
                print(f"data sampler bug: {e}")
                valid = False

            count_valid += 1
            if count_valid > 200:
                print("too large count_valid, check the sampler, current count_valid: "+str(count_valid))

        # For visual debugging, e.g.: self.show(data, 'template', 0, 'rgb')
        return data

    def show(self, data, strr, i, modality):
        """ Debug helper: displays one image of a sample with its box drawn, and waits for a key press.

        args:
            data - sample holding '<strr>_images' and '<strr>_anno'
            strr - key prefix, e.g. 'template' or 'search'
            i - index of the frame to display
            modality - 'rgb' shows the first three channels, anything else shows the remaining channels
        """
        import cv2

        image = data[strr+'_images'][i]
        # images are indexed channel-first here: first 3 channels vs. the rest
        if modality == 'rgb':
            image = image[:3, :, :]
        else:
            image = image[3:, :, :]
        _, H, W = image.shape

        # NOTE(review): the box is scaled by W/H, so it is presumably normalised (x, y, w, h) - confirm
        x1, y1, w, h = data[strr+'_anno'][i]
        x1, y1, w, h = int(x1*W), int(y1*H), int(w*W), int(h*H)

        image_show = image.permute(1, 2, 0).numpy()
        lowest = image_show.min()
        highest = image_show.max()
        span = highest - lowest
        if span == 0:
            # constant image: avoid dividing by zero, the result is an all-black picture
            span = 1
        image_show = (image_show - lowest) * 255 / span
        image_show = np.ascontiguousarray(image_show.astype('uint8'))

        cv2.rectangle(image_show, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)
        cv2.imshow(strr+str(i)+modality, image_show)
        # block until any key is pressed
        cv2.waitKey()

    def getitem_cls(self):
        """ Draws random samples for classification training until 'processing' marks one as valid.

        With probability 'pos_prob' the search frame comes from the same sequence as the template
        (label 1). Otherwise (label 0) the search frame is a frame of the same video where the target is
        invisible (its box is replaced by a centered box), or, when no such frame exists or the dataset is
        an image dataset, a frame from another randomly chosen sequence.
        Only the 'trident', 'trident_pro' and 'stark' sample modes are supported for video datasets.

        returns:
            TensorDict - dict containing all the data blocks, plus data["label"]
        """
        valid = False
        label = None
        while not valid:
            # Select a dataset
            dataset = random.choices(self.datasets, self.p_datasets)[0]
            is_video_dataset = dataset.is_video_sequence()

            # sample a sequence from the given dataset
            seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)

            # sample template and search frame ids
            if is_video_dataset:
                if self.frame_sample_mode in ["trident", "trident_pro"]:
                    template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
                elif self.frame_sample_mode == "stark":
                    template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
                else:
                    raise ValueError("illegal frame sample mode")
            else:
                # In case of image dataset, just repeat the image to generate synthetic video
                template_frame_ids = [1] * self.num_template_frames
                search_frame_ids = [1] * self.num_search_frames

            try:
                # "try" is used to handle trackingnet data failure
                # get images and bounding boxes (for templates)
                template_frames, template_anno, _ = dataset.get_frames(seq_id, template_frame_ids,
                                                                       seq_info_dict)
                H, W, _ = template_frames[0].shape
                template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros(
                    (H, W))] * self.num_template_frames

                # get images and bounding boxes (for searches)
                if random.random() < self.pos_prob:
                    # positive samples
                    label = torch.ones(1,)
                    search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
                                                                                   seq_info_dict)
                    search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
                        (H, W))] * self.num_search_frames
                else:
                    # negative samples
                    label = torch.zeros(1,)
                    if is_video_dataset:
                        search_frame_ids = self._sample_visible_ids(visible, num_ids=1, force_invisible=True)
                        if search_frame_ids is None:
                            # No invisible frame in this video: borrow a frame from another sequence and
                            # keep the ids that were actually sampled, so that the frame-id tensor below
                            # is not built from None (which would raise and discard the sample).
                            search_frames, search_anno, meta_obj_test, search_frame_ids = \
                                self._sample_one_search()
                        else:
                            search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
                                                                                           seq_info_dict)
                            search_anno["bbox"] = [self.get_center_box(H, W)]
                    else:
                        search_frames, search_anno, meta_obj_test = self.get_one_search()
                    H, W, _ = search_frames[0].shape
                    search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
                        (H, W))] * self.num_search_frames

                data = TensorDict({'template_images': template_frames,
                                   'template_anno': template_anno['bbox'],
                                   'template_masks': template_masks,
                                   'template_frame_ids': torch.tensor(template_frame_ids, dtype=torch.long),
                                   'search_images': search_frames,
                                   'search_anno': search_anno['bbox'],
                                   'search_masks': search_masks,
                                   'search_frame_ids': torch.tensor(search_frame_ids, dtype=torch.long),
                                   'dataset': dataset.get_name(),
                                   'test_class': meta_obj_test.get('object_class_name')})

                # make data augmentation
                data = self.processing(data)
                # add classification label
                data["label"] = label
                # check whether data is valid
                valid = data['valid']
            except Exception as e:
                # do not swallow KeyboardInterrupt / SystemExit, and report what went wrong
                print(f"data sampler bug: {e}")
                valid = False

        return data

    def get_center_box(self, H, W, ratio=1/8):
        """ Returns an integer (x, y, w, h) box centered in an H x W image, sized 'ratio' of each side. """
        cx, cy, w, h = W/2, H/2, W * ratio, H * ratio
        return torch.tensor([int(cx-w/2), int(cy-h/2), int(w), int(h)])

    def sample_seq_from_dataset(self, dataset, is_video_dataset):
        """ Samples a random sequence that has enough visible frames.

        A video sequence qualifies when it has at least 20 frames and more than
        2 * (num_search_frames + num_template_frames) visible ones; for image datasets the first draw is
        always accepted. Sampling is repeated until a sequence qualifies.

        returns:
            (seq_id, visible, seq_info_dict)
        """
        enough_visible_frames = False
        count = 0  # number of attempts, only used for the diagnostic message below
        min_visible = 2 * (self.num_search_frames + self.num_template_frames)
        while not enough_visible_frames:
            # Sample a sequence
            seq_id = random.randint(0, dataset.get_num_sequences() - 1)

            # Sample frames
            seq_info_dict = dataset.get_sequence_info(seq_id)
            visible = seq_info_dict['visible']

            enough_visible_frames = visible.type(torch.int64).sum().item() > min_visible and len(visible) >= 20
            enough_visible_frames = enough_visible_frames or not is_video_dataset

            count += 1
            if count > 200:
                print("too large count, check the sampler, current count: " + str(count))
        return seq_id, visible, seq_info_dict

    def _sample_one_search(self):
        """ Samples a single search frame from a randomly chosen dataset and sequence.

        returns:
            (search_frames, search_anno, meta_obj_test, search_frame_ids) - the loaded frame data together
            with the list of frame ids it was loaded from
        """
        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        # sample a sequence
        seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)

        # sample a frame
        if is_video_dataset:
            if self.frame_sample_mode == "stark":
                search_frame_ids = self._sample_visible_ids(seq_info_dict["valid"], num_ids=1)
            else:
                search_frame_ids = self._sample_visible_ids(visible, num_ids=1, allow_invisible=True)
        else:
            search_frame_ids = [1]

        # get the image, bounding box and other info
        search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
        return search_frames, search_anno, meta_obj_test, search_frame_ids

    def get_one_search(self):
        """ Samples a single search frame from a randomly chosen dataset and sequence.

        returns:
            (search_frames, search_anno, meta_obj_test)
        """
        search_frames, search_anno, meta_obj_test, _ = self._sample_one_search()
        return search_frames, search_anno, meta_obj_test

    def _get_frame_ids_multi_template(self, visible, dynamic_mask, allow_invisible=False):
        """ Shared sampling routine of the 'trident' and 'stark' modes.

        One initial template and one search frame are sampled from the visible frames. Then, for every gap
        in self.max_gap, one dynamic template is sampled from 'dynamic_mask' within that gap of the search
        frame, on the same side as the initial template. Everything is re-drawn until all dynamic templates
        could be found.

        args:
            visible - 1d Tensor used to sample the initial template and the search frame
            dynamic_mask - 1d Tensor used to sample the dynamic templates
            allow_invisible - if True, dynamic templates may be any frame in the window

        returns:
            (template_frame_ids, search_frame_ids)
        """
        template_frame_ids_extra = []
        while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
            template_frame_ids_extra = []
            # first randomly sample two frames from a video
            template_frame_id1 = self._sample_visible_ids(visible, num_ids=1)  # the initial template id
            search_frame_ids = self._sample_visible_ids(visible, num_ids=1)  # the search region id
            # get the dynamic template id
            for max_gap in self.max_gap:
                if template_frame_id1[0] >= search_frame_ids[0]:
                    min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
                else:
                    min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
                f_id = self._sample_visible_ids(dynamic_mask, num_ids=1, min_id=min_id, max_id=max_id,
                                                allow_invisible=allow_invisible)
                if f_id is None:
                    template_frame_ids_extra += [None]
                else:
                    template_frame_ids_extra += f_id

        template_frame_ids = template_frame_id1 + template_frame_ids_extra
        return template_frame_ids, search_frame_ids

    def get_frame_ids_trident(self, visible):
        """ Gets template and search ids in a 'trident' manner.

        In 'trident_pro' mode the dynamic templates are allowed to be invisible frames.
        """
        return self._get_frame_ids_multi_template(
            visible, visible, allow_invisible=(self.frame_sample_mode == "trident_pro"))

    def get_frame_ids_stark(self, visible, valid):
        """ Gets template and search ids in a 'stark' manner.

        The dynamic templates are required to be valid but not necessarily visible.
        """
        return self._get_frame_ids_multi_template(visible, valid)

    def get_frame_ids_order(self, visible):
        """ Gets template and search ids in an 'order' manner.

        All frames are sampled within max_gap of a random visible base frame and sorted, so that the
        template and search regions are arranged in chronological order (ascending or descending, each
        with probability 0.5). The gap is widened by 5 per attempt until frames are found.

        returns:
            (template_frame_ids, search_frame_ids)
        """
        num_frames = self.num_template_frames + self.num_search_frames
        frame_ids = None
        gap_increase = 0

        # 'not frame_ids' also covers a None result, which previously crashed the loop condition
        while not frame_ids:
            base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=0,
                                                     max_id=len(visible))
            frame_ids = self._sample_visible_ids(visible, num_ids=num_frames,
                                                 min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                 max_id=base_frame_id[0] + self.max_gap + gap_increase)
            # Increase gap until a frame is found
            gap_increase += 5
            if gap_increase > 1000:
                print("too large image gap, check the sampler, current gap: " + str(gap_increase))

        # random temporal direction
        frame_ids.sort(reverse=bool(torch.rand(1) < 0.5))
        template_frame_ids = frame_ids[0:self.num_template_frames]
        search_frame_ids = frame_ids[self.num_template_frames:]
        return template_frame_ids, search_frame_ids

    def extract_token_from_nlp(self, nlp, seq_length):
        """ use tokenizer to convert nlp to tokens

        NOTE: relies on self.tokenizer, which this class never sets; it has to be attached to the
        instance before this method is called.

        param:
            nlp: a sentence of natural language
            seq_length: the max token length, if token length larger than seq_len then cut it,
                        elif less than, append '0' token at the rear.
        return:
            token_ids and token_masks
        """
        nlp_token = self.tokenizer.tokenize(nlp)
        # leave room for the [CLS] and [SEP] markers
        if len(nlp_token) > seq_length - 2:
            nlp_token = nlp_token[0:(seq_length - 2)]

        # build tokens and token_ids
        tokens = ["[CLS]", *nlp_token, "[SEP]"]
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = seq_length - len(input_ids)
        if padding > 0:
            input_ids.extend([0] * padding)
            input_mask.extend([0] * padding)

        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length

        return input_ids, input_mask