| import base64 |
| import copy |
| import json |
| import os |
| from io import BytesIO |
| from urllib import request as urlrequest |
|
|
| import pandas as pd |
|
|
| from lib.test.tracker.basetracker import BaseTracker |
| import torch |
| import torch.nn as nn |
| from lib.test.tracker.atctrack_utils import sample_target, transform_image_to_crop |
| import cv2 |
| from lib.utils.box_ops import box_xywh_to_xyxy, box_xyxy_to_cxcywh, box_cxcywh_to_xyxy |
| from lib.utils.misc import NestedTensor |
| from lib.models.atctrack import build_atctrack |
| from lib.test.tracker.atctrack_utils import Preprocessor |
| from lib.utils.box_ops import clip_box |
| import numpy as np |
| from lib.test.utils.hann import hann2d |
| from lib.utils.ce_utils import generate_mask_cond,generate_bbox_mask |
| from matplotlib import pyplot as plt |
| from PIL import Image, ImageDraw |
| |
|
|
def get_resize_template_bbox(template_bbox, resize_factor, template_size=128):
    """Centre the rescaled target box inside the template crop.

    Args:
        template_bbox: target box as ``[x, y, w, h]``; only ``w`` and ``h`` are read.
        resize_factor: scale from image pixels to template-crop pixels.
        template_size: side length of the square template crop. The default of
            128 reproduces the original fixed centre (64, 64).

    Returns:
        ``[x0, y0, w, h]`` as ints in template-crop pixel coordinates.
    """
    w, h = template_bbox[2], template_bbox[3]
    w_1, h_1 = int(w * resize_factor), int(h * resize_factor)
    # The box is placed at the crop centre (assumes the crop is centred on
    # the target, as the callers' sample_target usage suggests).
    xc = yc = template_size / 2
    x0, y0 = int(xc - w_1 * 0.5), int(yc - h_1 * 0.5)
    return [x0, y0, w_1, h_1]
|
|
def visualize_grid_attention_v2(img, attention_mask, ratio=1, cmap="jet", save_image=True,
                                save_path="./test.jpg", quality=200):
    """Overlay a 2-D attention map on an image and optionally save the figure.

    Args:
        img: image accepted by ``plt.imshow`` (e.g. an (H, W[, C]) array or a PIL image).
        attention_mask: 2-D attention map (np.ndarray); it is resized to the image size.
        ratio: scaling factor applied to the figure's width and height.
        cmap: matplotlib colormap for the attention overlay, default "jet".
        save_image: if True, save the figure to ``save_path`` and close it.
        save_path: output file path.
        quality: DPI used when saving.
    """
    img_h, img_w = np.asarray(img).shape[:2]
    # figsize is (width, height) in inches.
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(0.02 * img_w * ratio, 0.02 * img_h * ratio))
    ax.imshow(img, alpha=1)
    ax.axis('off')

    # cv2.resize takes dsize as (width, height); match the image so the overlay aligns.
    mask = cv2.resize(np.asarray(attention_mask), (img_w, img_h))
    max_val = mask.max()
    # Guard against an all-zero (or NaN) map, which would otherwise divide by zero.
    if max_val > 0:
        normed_mask = mask / max_val
    else:
        normed_mask = np.zeros(mask.shape, dtype=np.float64)
    normed_mask = (np.clip(normed_mask, 0.0, 1.0) * 255).astype('uint8')
    ax.imshow(normed_mask, alpha=0.5, interpolation='nearest', cmap=cmap)

    if save_image:
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        fig.savefig(save_path, dpi=quality)
        # Release the figure; repeated calls would otherwise accumulate open figures.
        plt.close(fig)
| |
|
|
|
|
|
|
class TargetStateFusion(nn.Module):
    """Gated FiLM fusion of a target-state embedding into a feature map.

    A LayerNorm followed by a Linear layer turns the embedding into per-channel
    scale and shift terms. Both are multiplied by ``sigmoid(film_gate)``, whose
    logits start at -4.0, so the initial modulation is small.
    """

    def __init__(self, dim):
        super().__init__()
        self.film_ln = nn.LayerNorm(dim)
        self.film = nn.Linear(dim, dim * 2)
        # Per-channel gate logits (sigmoid(-4) is roughly 0.018).
        self.film_gate = nn.Parameter(torch.full((dim,), -4.0))

    def modulate_feature(self, opt_feat, z_target):
        """Return ``opt_feat * (1 + scale) + shift`` with gated FiLM terms.

        ``z_target`` is indexed as (batch, dim); the scale and shift are expanded
        with two trailing singleton axes, so ``opt_feat`` is expected to broadcast
        against (batch, dim, 1, 1) — presumably a (batch, dim, H, W) map.
        """
        film_params = self.film(self.film_ln(z_target))
        scale, shift = film_params.chunk(2, dim=-1)
        gate = self.film_gate.sigmoid()
        scale = (scale * gate)[:, :, None, None]
        shift = (shift * gate)[:, :, None, None]
        return opt_feat * (1.0 + scale) + shift
|
|
|
|
class ATCTRACK(BaseTracker):
    """Language-guided tracker with externally gated template updates.

    Each frame is tracked by running the network on the current template list,
    a search crop and cached text features. Periodically a candidate template
    is POSTed to an HTTP service (``_query_qwen_updater``). The service's
    decision controls whether the candidate replaces the oldest dynamic
    template (index 0 is always kept), and any ``z_target`` it returns is
    passed to the network on later frames.
    """

    # Per-dataset (update_threshold, update_edge). Template updates are only
    # attempted while frame_id < update_edge.
    _DATASET_UPDATE_SETTINGS = {
        "lasot_extension_subset_lang": (0.85, 500),
        "videocube_test_tiny": (0.80, 1000),
        "tnl2k": (0.70, 1e6),
        "lasot_lang": (0.90, 1e6),
    }
    _DEFAULT_UPDATE_SETTINGS = (0.80, 1e6)
    # Threshold used whenever QWEN_UPDATER_BASE_URL is set explicitly.
    _QWEN_URL_UPDATE_THRESHOLD = 0.6
    # Default VideoCube language-annotation directory; override with VIDEOCUBE_NL_DIR.
    _VIDEOCUBE_NL_DIR = '/home/data_d/video_ds/VideoCube/VideoCube-Full/VideoCube_NL/02-activity&story/'
    # Side length (pixels) of the square patches the template bbox mask is averaged over.
    _MASK_PATCH_SIZE = 16

    @staticmethod
    def _restore_target_state_embedding_row(network, checkpoint_dict):
        """Copy the saved target-token embedding row into the encoder's Qwen embeddings.

        Does nothing when the checkpoint has no ``target_state_embedding`` entry
        or the network has no target-state encoder.

        Raises:
            RuntimeError: if the saved token / token id differ from the model's.
        """
        row = checkpoint_dict.get('target_state_embedding', None)
        if row is None or not hasattr(network, 'target_state_encoder') or network.target_state_encoder is None:
            return
        encoder = network.target_state_encoder
        token = row.get('token')
        token_id = int(row.get('token_id'))
        if token != encoder.token or token_id != int(encoder.target_token_id):
            raise RuntimeError(
                'Target-state token mismatch: checkpoint has {} / {}, current model has {} / {}'.format(
                    token, token_id, encoder.token, int(encoder.target_token_id)
                )
            )
        weight = row.get('weight')
        with torch.no_grad():
            embedding = encoder.qwen.get_input_embeddings().weight
            embedding[token_id].copy_(weight.to(device=embedding.device, dtype=embedding.dtype))

    def _load_checkpoint(self, network, checkpoint):
        """Load ``checkpoint['net']`` into ``network``.

        Target-state-encoder keys are dropped when the network has no such
        encoder. Lightweight checkpoints load non-strictly and also restore the
        target-token embedding row; full checkpoints load strictly.
        """
        state_dict = checkpoint['net']
        if getattr(network, 'target_state_encoder', None) is None:
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('target_state_encoder.')}
        if bool(checkpoint.get('lightweight_checkpoint', False)):
            missing_keys, unexpected_keys = network.load_state_dict(state_dict, strict=False)
            self._restore_target_state_embedding_row(network, checkpoint)
            print(
                'Loaded lightweight tracker checkpoint from {} with {} missing keys and {} unexpected keys.'.format(
                    self.params.checkpoint, len(missing_keys), len(unexpected_keys)
                )
            )
        else:
            network.load_state_dict(state_dict, strict=True)
            print("load from ", self.params.checkpoint)

    @staticmethod
    def _tensor_to_data_url(image_arr):
        """Encode an image array as a base64 JPEG ``data:`` URL."""
        image = Image.fromarray(np.asarray(image_arr).astype(np.uint8))
        buffer = BytesIO()
        image.save(buffer, format='JPEG')
        return 'data:image/jpeg;base64,' + base64.b64encode(buffer.getvalue()).decode('utf-8')

    @staticmethod
    def _parse_update_decision(text):
        """Parse a yes/no decision from generated text.

        Only the last ``<answer>...</answer>`` span is inspected when one is
        present; otherwise the whole text is used. Returns True only when the
        token "yes" appears and "no" does not.
        """
        if not text:
            return False
        text_l = text.lower()
        open_tag, close_tag = '<answer>', '</answer>'
        answer_start = text_l.rfind(open_tag)
        answer_end = text_l.find(close_tag, answer_start + len(open_tag)) if answer_start >= 0 else -1
        if answer_start >= 0 and answer_end >= 0:
            answer = text_l[answer_start + len(open_tag):answer_end].strip()
        else:
            answer = text_l
        answer = answer.replace('<|im_end|>', ' ').replace('<|endoftext|>', ' ')
        tokens = answer.replace('<', ' ').replace('>', ' ').replace('/', ' ').split()
        return 'yes' in tokens and 'no' not in tokens

    def _load_target_state_fusion(self, checkpoint, tracker_dim):
        """Build a ``TargetStateFusion`` from the ``target_state_encoder.film*`` checkpoint keys.

        Missing LayerNorm (``film_ln``) keys are tolerated.

        Raises:
            RuntimeError: if any other key is missing or any key is unexpected.
        """
        fusion = TargetStateFusion(tracker_dim).to(self.device)
        prefix = 'target_state_encoder.'
        state_dict = {}
        for key, value in checkpoint.get('net', {}).items():
            if (key.startswith((prefix + 'film_ln.', prefix + 'film.'))
                    or key == prefix + 'film_gate'):
                state_dict[key[len(prefix):]] = value

        missing, unexpected = fusion.load_state_dict(state_dict, strict=False)
        critical_missing = [k for k in missing if 'film_ln' not in k]
        if critical_missing:
            raise RuntimeError(f'Fusion state missing critical keys: {critical_missing[:20]}')
        if unexpected:
            raise RuntimeError(f'Fusion state unexpected keys: {unexpected[:20]}')
        return fusion

    def _query_qwen_updater(self, template_arr, template_bbox, candidate_arr, candidate_bbox):
        """POST the reference and candidate templates to the updater service.

        Returns:
            ``(decision, z_target, output)``. ``z_target`` is a (1, D) float32
            tensor on ``self.device`` or None. Any failure (network, JSON,
            tensor conversion) is logged and yields ``(False, None, None)``,
            so tracking continues without an update.
        """
        payload = {
            'template_image': self._tensor_to_data_url(template_arr),
            'candidate_image': self._tensor_to_data_url(candidate_arr),
            'template_bbox': [float(v) for v in template_bbox],
            'candidate_bbox': [float(v) for v in candidate_bbox],
            'caption': self.target_state_caption,
            'object_name': self.target_state_object_name,
        }
        req = urlrequest.Request(
            self.qwen_updater_base_url.rstrip('/') + '/update',
            data=json.dumps(payload).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
            method='POST',
        )
        try:
            with urlrequest.urlopen(req, timeout=self.qwen_updater_timeout) as resp:
                result = json.loads(resp.read().decode('utf-8'))
            decision = bool(result.get('decision', False))
            z_target = result.get('z_target', None)
            if z_target is not None:
                z_target = torch.tensor(z_target, dtype=torch.float32, device=self.device).view(1, -1)
            output = result.get('output', None)
            return decision, z_target, output
        except Exception as exc:  # best-effort: the updater is optional at inference time
            print(f'Qwen URL updater failed: {exc}')
            return False, None, None

    def __init__(self, params, dataset_name):
        super(ATCTRACK, self).__init__(params)
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(params.checkpoint, map_location='cpu', weights_only=False)
        use_lightweight = bool(checkpoint.get('lightweight_checkpoint', False))

        # Build the network with its target-state encoder disabled; the FiLM
        # fusion module is restored separately from the checkpoint below.
        tracker_cfg = copy.deepcopy(params.cfg)
        if hasattr(tracker_cfg.MODEL, 'TARGET_STATE'):
            tracker_cfg.MODEL.TARGET_STATE.ENABLE = False
        network = build_atctrack(tracker_cfg, training=use_lightweight)
        self._load_checkpoint(network, checkpoint)

        self.cfg = tracker_cfg
        self.seq_format = self.cfg.DATA.SEQ_FORMAT
        self.num_template = self.cfg.TEST.NUM_TEMPLATES
        self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
        # Centered Hann window multiplied into the score map in track().
        self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda()

        self.network = network.cuda()
        self.network.eval()
        self.device = next(self.network.parameters()).device
        self.network.target_state_encoder = self._load_target_state_fusion(checkpoint, self.cfg.MODEL.HIDDEN_DIM)
        self.network.target_state_encoder.eval()
        self.qwen_updater_base_url = os.environ.get('QWEN_UPDATER_BASE_URL', 'http://127.0.0.1:8001').rstrip('/')
        self.qwen_updater_timeout = float(os.environ.get('QWEN_UPDATER_TIMEOUT', '30'))
        self.preprocessor = Preprocessor()
        self.state = None
        self.debug = params.debug
        self.frame_id = 0

        self.dataset_name = dataset_name
        dataset_key = dataset_name.upper()
        if hasattr(self.cfg.TEST.UPDATE_INTERVALS, dataset_key):
            self.update_intervals = self.cfg.TEST.UPDATE_INTERVALS[dataset_key]
        else:
            self.update_intervals = self.cfg.TEST.UPDATE_INTERVALS.DEFAULT
        print("Update interval is: ", self.update_intervals)

        # The per-dataset table alone decides the threshold; cfg.TEST.UPDATE_THRESHOLD
        # is intentionally not consulted.
        self.update_threshold, self.update_edge = self._DATASET_UPDATE_SETTINGS.get(
            self.dataset_name, self._DEFAULT_UPDATE_SETTINGS)
        # Applies to every dataset when the updater URL is set explicitly.
        if os.environ.get('QWEN_UPDATER_BASE_URL') is not None:
            self.update_threshold = self._QWEN_URL_UPDATE_THRESHOLD
        print("Update threshold is: ", self.update_threshold)

        if "videocube" in self.dataset_name:
            # Which VideoCube annotation granularity drives the language query.
            self.action_level = 1
            self.activity_level = 0
            self.story_level = 0
        print(self.dataset_name)

    def _encode_text(self, caption, device):
        """Run the network's text branch on ``caption`` and cache its four outputs."""
        # Inference only: do not keep autograd graphs alive in the cached features.
        with torch.no_grad():
            (self.text_features, self.text_subject_features,
             self.subject_infor_mask_pred, self.subject_infor_mask_gt) = self.network.forward_text(
                [caption], num_search=1, exp_subject_mask=None, device=device)

    def _template_bbox_mask(self, resize_template_bbox, device):
        """Build the patch-averaged soft mask for a template box.

        ``resize_template_bbox`` is ``[x0, y0, w, h]`` in template-crop pixels.
        The pixel mask from ``generate_bbox_mask`` is averaged over
        ``_MASK_PATCH_SIZE`` square patches and reshaped to (N, num_patches, 1),
        where N is the mask's leading dim (1, assuming generate_bbox_mask keeps
        the input shape — TODO confirm).
        """
        patch = self._MASK_PATCH_SIZE
        bbox_mask = torch.zeros((1, self.params.template_size, self.params.template_size))
        bbox_mask = generate_bbox_mask(bbox_mask, [torch.tensor(resize_template_bbox).to(device)])
        bbox_mask = bbox_mask.unfold(1, patch, patch).unfold(2, patch, patch)
        bbox_mask = bbox_mask.mean(dim=(-1, -2)).view(bbox_mask.shape[0], -1).unsqueeze(-1)
        return bbox_mask.to(device)

    @staticmethod
    def _read_segments(tab):
        """Return (start_frames, end_frames, descriptions) lists from an annotation table."""
        return list(tab['start_frame']), list(tab['end_frame']), list(tab['description'])

    def _load_videocube_language(self, info):
        """Load VideoCube action/activity/story segments and set ``info['init_nlp']``."""
        self.frame_index = 0
        seq_name = self.seq_name
        print(seq_name)
        nl_dir = os.environ.get('VIDEOCUBE_NL_DIR', self._VIDEOCUBE_NL_DIR)
        dataset_tab = pd.read_excel(os.path.join(nl_dir, seq_name + '.xlsx'), index_col=0)
        # Rows are selected by their index label ('action' / 'activity' / 'story').
        (self.action_start_frames, self.action_end_frames,
         self.actions) = self._read_segments(dataset_tab['action': 'action'])
        (self.activity_start_frames, self.activity_end_frames,
         self.activities) = self._read_segments(dataset_tab['activity': 'activity'])
        (self.story_start_frames, self.story_end_frames,
         self.story) = self._read_segments(dataset_tab['story': 'story'])

        if self.action_level:
            info['init_nlp'] = self.actions[0]
            print('language', info['init_nlp'])
        elif self.activity_level:
            info['init_nlp'] = self.activities[0]
            print('language', info['init_nlp'])
        elif self.story_level:
            info['init_nlp'] = self.story[0]
            print('language', info['init_nlp'])

    def initialize(self, image, info: dict):
        """Set up templates, template masks and language features from the first frame.

        ``info`` must provide ``seq_name`` and ``init_bbox``. ``init_nlp`` is
        required unless the dataset is VideoCube, where it is read from the
        annotation spreadsheet. ``object_class_name`` is optional.
        """
        self.seq_name = info["seq_name"]
        if 'videocube' in self.dataset_name:
            self._load_videocube_language(info)

        z_patch_arr, resize_factor = sample_target(image, info['init_bbox'], self.params.template_factor,
                                                   output_sz=self.params.template_size)
        template = self.preprocessor.process(z_patch_arr)
        self.target_state_template_image_arr = z_patch_arr
        self.template_list = [template] * self.num_template

        resize_template_bbox = get_resize_template_bbox(info['init_bbox'], resize_factor)
        self.target_state_template_bbox = torch.tensor(
            resize_template_bbox, device=template.device, dtype=torch.float32).view(1, 4)
        bbox_mask = self._template_bbox_mask(resize_template_bbox, template.device)
        # NOTE(review): always two masks while template_list has num_template
        # entries — confirm the network expects this when NUM_TEMPLATES != 2.
        self.soft_token_template_mask = [bbox_mask, bbox_mask]

        self.target_state_caption = info['init_nlp']
        self.target_state_object_name = info.get('object_class_name', None)
        self._encode_text(info['init_nlp'], template.device)
        self.device = template.device

        self.state = info['init_bbox']
        self.frame_id = 0
        self.first_frame_flag = True
        self.temporal_infor = []
        self.cached_target_state_z = None
        self.cached_target_state_decision = None
        self.cached_target_state_outputs = None

    def _maybe_switch_language(self):
        """Re-encode the caption when the frame enters a different VideoCube segment.

        The granularity (action, then activity, then story) is the one enabled
        in ``__init__``. The first segment whose [start, end] range contains
        ``self.frame_id`` is selected. Its text is re-encoded only if its index
        differs from ``self.frame_index``.
        """
        if self.action_level:
            starts, ends, texts = self.action_start_frames, self.action_end_frames, self.actions
            index_label, text_label = 'action_level self.frame_index', 'actions'
        elif self.activity_level:
            starts, ends, texts = self.activity_start_frames, self.activity_end_frames, self.activities
            index_label, text_label = 'activity_level self.frame_index', 'activities'
        elif self.story_level:
            starts, ends, texts = self.story_start_frames, self.story_end_frames, self.story
            index_label, text_label = 'story_level self.frame_index', 'story'
        else:
            return

        for i, (start, end) in enumerate(zip(starts, ends)):
            if start <= self.frame_id <= end:
                if self.frame_index != i:
                    # Jump straight to the matched segment; stepping by one would lag
                    # after a skipped segment and re-encode on every frame.
                    self.frame_index = i
                    print(index_label, self.frame_index)
                    print(text_label, texts[i])
                    self.target_state_caption = texts[i]
                    self._encode_text(texts[i], self.device)
                break

    def _maybe_update_template(self, image):
        """Ask the updater service whether a template cropped at ``self.state`` should be adopted.

        The service's ``z_target`` (when returned) is cached for the next
        network call. The template list and masks change only when the service
        says yes AND returns a ``z_target``. In that case the candidate replaces
        the oldest dynamic template (index 1) once the list exceeds
        ``num_template``, and it becomes the new reference for later queries.
        """
        z_patch_arr, resize_factor = sample_target(image, self.state, self.params.template_factor,
                                                   output_sz=self.params.template_size)
        template = self.preprocessor.process(z_patch_arr)
        resize_template_bbox = get_resize_template_bbox(self.state, resize_factor)
        candidate_bbox = torch.tensor(
            resize_template_bbox, device=template.device, dtype=torch.float32).view(1, 4)

        should_update, qwen_z_target, qwen_output = self._query_qwen_updater(
            self.target_state_template_image_arr,
            self.target_state_template_bbox.view(-1).detach().cpu().tolist(),
            z_patch_arr,
            candidate_bbox.view(-1).detach().cpu().tolist(),
        )
        self.cached_target_state_outputs = qwen_output
        if qwen_z_target is not None:
            self.cached_target_state_z = qwen_z_target.detach()
        self.cached_target_state_decision = bool(should_update)
        if not (should_update and qwen_z_target is not None):
            return

        self.template_list.append(template)
        if len(self.template_list) > self.num_template:
            self.template_list.pop(1)

        self.target_state_template_image_arr = z_patch_arr
        self.target_state_template_bbox = candidate_bbox

        # Same integer-box mask path as initialize() so masks stay consistent.
        self.soft_token_template_mask.append(self._template_bbox_mask(resize_template_bbox, template.device))
        if len(self.soft_token_template_mask) > self.num_template:
            self.soft_token_template_mask.pop(1)

    def track(self, image, info: dict = None):
        """Track the target in ``image``.

        Returns:
            ``{"target_bbox": [x, y, w, h], "best_score": float}``.
        """
        H, W, _ = image.shape
        self.frame_id += 1

        if 'videocube' in self.dataset_name:
            self._maybe_switch_language()

        x_patch_arr, resize_factor = sample_target(image, self.state, self.params.search_factor,
                                                   output_sz=self.params.search_size)
        search = self.preprocessor.process(x_patch_arr)

        with torch.no_grad():
            out_dict = self.network(self.template_list, search, self.soft_token_template_mask,
                                    exp_str=self.text_features,
                                    exp_subject_mask=self.text_subject_features,
                                    target_state_z=self.cached_target_state_z,
                                    temporal_infor=self.temporal_infor,
                                    first_frame_flag=self.first_frame_flag,
                                    training=False)
            response = self.output_window * out_dict['score_map']
            pred_boxes, best_score = self.network.box_head.cal_bbox(
                response, out_dict['size_map'], out_dict['offset_map'], return_score=True)
        self.first_frame_flag = False
        self.temporal_infor = out_dict["temporal_infor"]

        conf_score = best_score[0][0].item()
        pred_boxes = pred_boxes.view(-1, 4)
        # Average the predicted boxes, scale by search_size / resize_factor, then
        # map back to image coordinates around the previous state.
        pred_box = (pred_boxes.mean(dim=0) * self.params.search_size / resize_factor).tolist()
        self.state = clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        if (self.num_template > 1 and self.frame_id < self.update_edge
                and self.frame_id % self.update_intervals == 0
                and conf_score > self.update_threshold):
            self._maybe_update_template(image)

        return {"target_bbox": self.state,
                "best_score": conf_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        """Map a ``[cx, cy, w, h]`` search-crop box back to ``[x, y, w, h]`` image coords."""
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        """Tensor version of ``map_box_back`` over the last dimension."""
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)
|
|
|
|
|
|
|
|
def get_tracker_class():
    """Return the tracker class this module provides (``ATCTRACK``)."""
    tracker_cls = ATCTRACK
    return tracker_cls
|
|