| """ |
| ATCTrack Model |
| """ |
| import os |
| import re |
| import base64 |
| from io import BytesIO |
|
|
| import torch |
| import math |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from lib.utils.misc import NestedTensor |
|
|
| |
| from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy, box_xyxy_to_cxcywh, box_iou |
| |
| from lib.models.aqatrack.hivit import hivit_small, hivit_base |
| from lib.models.aqatrack.itpn import itpn_base_3324_patch16_224 |
| from lib.models.aqatrack.fast_itpn import fast_itpn_base_3324_patch16_224,fast_itpn_large_2240_patch16_256 |
|
|
| from lib.models.transformers.transformer import build_rgb_det_decoder |
| from lib.models.layers.transformer_dec import build_transformer_dec,build_transformer_dec_with_mask |
|
|
| from torch.nn.modules.transformer import _get_clones |
| from lib.models.layers.head import build_box_head |
|
|
| import torch.nn.functional as F |
| from lib.models.layers.frozen_bn import FrozenBatchNorm2d |
| from transformers import BertTokenizer, BertModel, RobertaModel, RobertaTokenizerFast, AutoTokenizer, AutoProcessor |
| from PIL import Image, ImageDraw |
| from lib.models.transformers import build_decoder, VisionLanguageFusionModule, PositionEmbeddingSine1D,build_text_prompt_decoder |
|
|
|
|
# Default special token whose final hidden state QwenTargetStateEncoder reads
# out as the target-state embedding; overridable via MODEL.TARGET_STATE.TOKEN.
TARGET_STATE_TOKEN = "<TARGET_STATE>"
|
|
|
|
def _project_root():
    """Return the absolute path three directory levels above this module's folder."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.abspath(os.path.join(module_dir, os.pardir, os.pardir, os.pardir))
|
|
|
|
| def _resolve_project_path(path): |
| if not path or os.path.isabs(path): |
| return path |
| candidate = os.path.abspath(os.path.join(_project_root(), path)) |
| if os.path.exists(candidate) or path.startswith((".", "..", "checkpoint", "resource")): |
| return candidate |
| return path |
|
|
|
|
def _load_qwen_target_state_model(model_path):
    """Load the Qwen checkpoint used by the target-state encoder.

    Prefers ``AutoModelForImageTextToText`` and falls back to
    ``AutoModelForCausalLM`` when the installed transformers release does not
    provide the former.

    Raises:
        RuntimeError: if transformers raises ``ValueError`` while loading
            (e.g. an unrecognised architecture); the original error is chained.
    """
    try:
        from transformers import AutoModelForImageTextToText as loader_cls
    except ImportError:
        from transformers import AutoModelForCausalLM as loader_cls

    try:
        model = loader_cls.from_pretrained(model_path, trust_remote_code=True)
    except ValueError as exc:
        raise RuntimeError(
            "Cannot load Qwen target-state model. The current transformers package "
            "does not recognize this Qwen architecture. Upgrade transformers in the "
            "training environment before enabling MODEL.TARGET_STATE."
        ) from exc
    return model
|
|
|
|
class QwenTargetStateEncoder(nn.Module):
    """Qwen vision-language judge that also emits a target-state embedding.

    For each sample the template and search crops (boxes drawn in red) are
    shown to Qwen, which answers whether the search frame should update the
    tracking template. The last-layer hidden state at the special target-state
    token is projected to ``tracker_dim``. It can then modulate tracker
    features through :meth:`modulate_feature` (FiLM).

    An optional OpenAI-compatible teacher endpoint (plus an optional label
    cache) supplies yes/no labels that supervise the student's answer token.
    """

    def __init__(self, cfg, tracker_dim):
        """Build the encoder from ``cfg.MODEL.TARGET_STATE``.

        Args:
            cfg: experiment config; only ``cfg.MODEL.TARGET_STATE`` is read.
            tracker_dim: channel width of the tracker features to modulate.

        The environment variables ``QWEN_MODEL_PATH``, ``QWEN_TEACHER_ENABLE``,
        ``QWEN_TEACHER_MODEL``, ``QWEN_TEACHER_BASE_URL`` and
        ``QWEN_TEACHER_API_KEY`` override the matching config entries.
        """
        super().__init__()
        ts_cfg = cfg.MODEL.TARGET_STATE
        self.model_path = _resolve_project_path(os.environ.get("QWEN_MODEL_PATH", ts_cfg.MODEL_PATH))
        self.token = getattr(ts_cfg, "TOKEN", TARGET_STATE_TOKEN)
        self.prompt_template = getattr(ts_cfg, "PROMPT_TEMPLATE", "default")
        self.train_token_embedding = getattr(ts_cfg, "TRAIN_TOKEN_EMBEDDING", False)
        self.freeze_qwen = getattr(ts_cfg, "FREEZE_QWEN", True)
        self.use_lora = getattr(ts_cfg, "USE_LORA", False)
        self.lora_r = getattr(ts_cfg, "LORA_R", 8)
        self.lora_alpha = getattr(ts_cfg, "LORA_ALPHA", 16)
        self.lora_dropout = getattr(ts_cfg, "LORA_DROPOUT", 0.05)
        self.lora_target_modules = getattr(ts_cfg, "LORA_TARGET_MODULES", [
            "in_proj_qkv", "out_proj", "in_proj_z", "in_proj_b", "in_proj_a",
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ])
        # The environment variable, when set, wins over the config flag.
        teacher_enable_env = os.environ.get("QWEN_TEACHER_ENABLE")
        if teacher_enable_env is None:
            self.teacher_enable = bool(getattr(ts_cfg, "TEACHER_ENABLE", False))
        else:
            self.teacher_enable = teacher_enable_env.strip().lower() in ("1", "true", "yes", "on")
        self.teacher_model = os.environ.get("QWEN_TEACHER_MODEL", getattr(ts_cfg, "TEACHER_MODEL", "qwen3.5"))
        self.teacher_base_url = os.environ.get("QWEN_TEACHER_BASE_URL", getattr(ts_cfg, "TEACHER_BASE_URL", "http://127.0.0.1:8001/v1"))
        self.teacher_api_key = os.environ.get("QWEN_TEACHER_API_KEY", getattr(ts_cfg, "TEACHER_API_KEY", "sk-no-key-required"))
        self.teacher_client = None  # created lazily by _get_teacher_client

        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        self.tokenizer = getattr(self.processor, "tokenizer", None)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"
        self.qwen = _load_qwen_target_state_model(self.model_path)

        # Register the answer/state-token markup as atomic special tokens and
        # grow the embedding table when any of them are new.
        self.target_state_special_tokens = ["<answer>", "</answer>", "<state_token>", "</state_token>", self.token]
        special_tokens = {"additional_special_tokens": self.target_state_special_tokens}
        num_added = self.tokenizer.add_special_tokens(special_tokens)
        if num_added > 0:
            self.qwen.resize_token_embeddings(len(self.tokenizer))
        self.target_token_id = self.tokenizer.convert_tokens_to_ids(self.token)
        self._embedding_grad_hook = None

        qwen_hidden_dim = self.qwen.config.text_config.hidden_size if hasattr(self.qwen.config, "text_config") else self.qwen.config.hidden_size
        self.projector = nn.Sequential(
            nn.Linear(qwen_hidden_dim, tracker_dim),
            nn.LayerNorm(tracker_dim),
            nn.GELU(),
            nn.Linear(tracker_dim, tracker_dim),
        )

        # FiLM head; the gate starts at sigmoid(-4) so modulation begins near identity.
        self.film_ln = nn.LayerNorm(tracker_dim)
        self.film = nn.Linear(tracker_dim, tracker_dim * 2)
        self.film_gate = nn.Parameter(torch.full((tracker_dim,), -4.0))

        if self.freeze_qwen:
            for p in self.qwen.parameters():
                p.requires_grad = False
        if self.use_lora:
            self._enable_qwen_lora()
        self.configure_token_embedding_training(self.train_token_embedding)

        # Optional teacher-label cache, attached via set_teacher_label_cache.
        self.teacher_label_cache = None

    def set_teacher_label_cache(self, cache):
        """Attach a :class:`TeacherLabelCache` for two-stage training.

        When set, ``_query_teacher_decisions`` checks the cache before calling
        the online teacher API. Cache misses fall back to the online teacher
        and the result is written back to the cache.
        """
        self.teacher_label_cache = cache

    def _enable_qwen_lora(self):
        """Wrap ``self.qwen`` with PEFT LoRA adapters built from the LORA_* settings.

        Raises:
            RuntimeError: if the ``peft`` package is not installed.
        """
        try:
            from peft import LoraConfig, get_peft_model
        except ImportError as exc:
            raise RuntimeError("MODEL.TARGET_STATE.USE_LORA=True requires the peft package.") from exc

        target_modules = self.lora_target_modules
        if isinstance(target_modules, str):
            # Accept a comma-separated string as well as a list.
            target_modules = [item.strip() for item in target_modules.split(",") if item.strip()]
        config = LoraConfig(
            r=self.lora_r,
            lora_alpha=self.lora_alpha,
            target_modules=list(target_modules),
            lora_dropout=self.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.qwen = get_peft_model(self.qwen, config)

    def configure_token_embedding_training(self, enabled):
        """Toggle training of the target-state token's input embedding row.

        When enabled, the whole embedding matrix requires grad, but a gradient
        hook zeroes every row except ``self.target_token_id``. Any previously
        registered hook is removed first.
        """
        embedding = self.qwen.get_input_embeddings()
        embedding.weight.requires_grad = bool(enabled)
        if self._embedding_grad_hook is not None:
            self._embedding_grad_hook.remove()
            self._embedding_grad_hook = None
        if enabled:
            train_token_ids = torch.tensor([self.target_token_id], dtype=torch.long)

            def mask_embedding_grad(grad):
                token_ids = train_token_ids.to(grad.device)
                mask = torch.zeros((grad.shape[0],), device=grad.device, dtype=grad.dtype)
                mask.index_fill_(0, token_ids, 1)
                return grad * mask.view(-1, 1)

            self._embedding_grad_hook = embedding.weight.register_hook(mask_embedding_grad)

    def _qwen_forward_with_target_embedding(self, tokenized, labels=None):
        """Run a teacher-forced Qwen forward pass that also returns hidden states."""
        return self.qwen(**tokenized, labels=labels, output_hidden_states=True, use_cache=False)

    @staticmethod
    def _tensor_batch_to_pil(images, boxes=None):
        """Convert ImageNet-normalised ``(B, 3, H, W)`` tensors to PIL images.

        If ``boxes`` is given, one ``(x, y, w, h)`` box per image is drawn in
        red. A box whose coordinates all have magnitude <= 2 is treated as
        normalised and scaled by the image size. Boxes are clamped to the image
        and skipped if they collapse to zero area.
        """
        mean = images.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = images.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        images = (images.detach().float() * std + mean).clamp(0, 1)
        images = (images * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
        pil_images = [Image.fromarray(image) for image in images]

        if boxes is None:
            return pil_images

        boxes = boxes.detach().float().cpu()
        for image, box in zip(pil_images, boxes):
            draw = ImageDraw.Draw(image)
            x, y, w, h = box.tolist()
            if max(abs(x), abs(y), abs(w), abs(h)) <= 2.0:
                img_w, img_h = image.size
                x, w = x * img_w, w * img_w
                y, h = y * img_h, h * img_h
            x1 = max(0.0, min(float(image.size[0] - 1), x))
            y1 = max(0.0, min(float(image.size[1] - 1), y))
            x2 = max(0.0, min(float(image.size[0] - 1), x + w))
            y2 = max(0.0, min(float(image.size[1] - 1), y + h))
            if x2 > x1 and y2 > y1:
                line_width = max(2, round(min(image.size) / 80))
                draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), width=line_width)
        return pil_images

    def _build_prompt(self, caption, object_name=None):
        """Build the student prompt that asks for a yes/no answer plus the state token."""
        caption = caption if caption else "the target object"
        object_name = object_name if object_name else caption
        return (
            f"Role: {object_name} tracking update judge and target-state token generator.\n\n"
            "Task: Compare the targets inside the provided bboxes in Frame 1 (Original) "
            "and Frame 2 (New). Decide whether Frame 2 should update the tracking template, "
            "and generate a target-state token for the tracking model.\n\n"
            "Reject update for full occlusion, out of view, too small target, severe blur/clipping, "
            "wrong bbox, distractor, uncertain identity, or no meaningful target appearance change.\n\n"
            "Accept update only if Frame 2 contains the same target as Frame 1, the bbox is reliable, "
            "the target is clear, and the appearance change is useful for future tracking.\n\n"
            "The target-state token should summarize the current target condition for the tracking model. "
            "It should encode whether the candidate is reliable, whether the target identity is consistent, "
            "and whether the current appearance is useful or risky for tracking.\n\n"
            "Frame 1 (Original) is the first image. Frame 2 (New candidate/search crop) is the second image.\n\n"
            "Output exactly one answer XML tag containing yes or no, immediately followed by one "
            "state_token XML tag containing the special target-state token. Do not output any extra text."
        )

    def _build_teacher_prompt(self, caption, object_name=None):
        """Build the teacher prompt, which asks only for ``<answer>yes|no</answer>``."""
        caption = caption if caption else "the target object"
        object_name = object_name if object_name else caption
        return (
            f"Role: {object_name} tracking update judge.\n"
            "Task: Compare the targets inside the provided bboxes in Frame 1 (Original) and Frame 2 (New), "
            "and decide whether Frame 2 should update the tracking template.\n\n"
            "Reject update for full occlusion, out of view, too small target, severe blur/clipping, wrong bbox, "
            "distractor, uncertain identity, or no meaningful target appearance change.\n"
            "Accept only if Frame 2 is the same target, bbox is reliable, target is clear, and appearance change is useful.\n\n"
            "CRITICAL: Your entire response must be ONLY one of these two strings, "
            "with no other text, no explanation, no reasoning:\n"
            "<answer>yes</answer>\n"
            "<answer>no</answer>"
        )

    @staticmethod
    def _pil_to_base64_jpeg(image):
        """Encode a PIL image as a base64 JPEG string (for data: URLs)."""
        buffer = BytesIO()
        image.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def _get_teacher_client(self):
        """Return the cached OpenAI client for the teacher endpoint, creating it on first use.

        Raises:
            RuntimeError: if the ``openai`` package is not installed.
        """
        if self.teacher_client is None:
            try:
                from openai import OpenAI
            except ImportError as exc:
                raise RuntimeError("Teacher update judge requires the openai package.") from exc
            self.teacher_client = OpenAI(base_url=self.teacher_base_url, api_key=self.teacher_api_key, timeout=5.0)
        return self.teacher_client

    def _query_teacher_decisions(self, prompts, template_pils, search_pils,
                                 seq_names=None, frame_ids_a=None, frame_ids_b=None):
        """Query teacher API, with optional two-stage cache support.

        When ``seq_names`` / ``frame_ids_a`` / ``frame_ids_b`` are provided
        and a ``teacher_label_cache`` is attached, cached decisions are used
        directly. Cache misses fall back to the online teacher API with
        retry logic, and the result is saved back to the cache.

        Returns:
            ``(None, None)`` when the teacher is disabled. Otherwise two lists
            of batch length: decisions (True/False, or None when the teacher
            failed) and the raw or reconstructed teacher responses.
        """
        if not self.teacher_enable:
            return None, None

        batch_size = len(prompts)
        decisions = [None] * batch_size
        responses = [None] * batch_size

        have_frame_info = (
            self.teacher_label_cache is not None
            and seq_names is not None
            and frame_ids_a is not None
            and frame_ids_b is not None
        )
        uncached_indices = list(range(batch_size))
        if have_frame_info:
            uncached_indices = []
            for i in range(batch_size):
                cached = self.teacher_label_cache.get(
                    seq_names[i], frame_ids_a[i], frame_ids_b[i]
                )
                if cached is not None:
                    decisions[i] = cached
                    responses[i] = f"<answer>{'yes' if cached else 'no'}</answer>"
                else:
                    uncached_indices.append(i)

        if not uncached_indices:
            return decisions, responses

        import time as _time
        client = self._get_teacher_client()
        max_retries = 3
        retry_delay = 2.0

        for idx_in_uncached, i in enumerate(uncached_indices):
            prompt, template_pil, search_pil = prompts[i], template_pils[i], search_pils[i]
            base64_image1 = self._pil_to_base64_jpeg(template_pil)
            base64_image2 = self._pil_to_base64_jpeg(search_pil)
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image1}"}},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image2}"}},
                    {"type": "text", "text": prompt},
                ],
            }]

            success = False
            last_error = None
            for attempt in range(1, max_retries + 1):
                try:
                    chat_response = client.chat.completions.create(
                        model=self.teacher_model,
                        messages=messages,
                        max_tokens=8,
                        temperature=0.0,
                        top_p=1.0,
                        presence_penalty=0.0,
                        frequency_penalty=0.0,
                        extra_body={
                            "top_k": 1,
                            "seed": 0,
                            "chat_template_kwargs": {"enable_thinking": False},
                            "guided_choice": ["<answer>yes</answer>", "<answer>no</answer>"],
                        },
                    )
                    content = chat_response.choices[0].message.content
                    if content is None:
                        # Message content can be null; treat it as an empty,
                        # unparseable answer instead of crashing in re.findall.
                        content = ""

                    match = re.findall(r"<answer>\s*(yes|no)\s*</answer>", content, flags=re.IGNORECASE)
                    if match:
                        decisions[i] = match[-1].lower() == "yes"
                        success = True
                    else:
                        # No well-formed answer tag: fall back to counting bare yes/no words.
                        yes_count = len(re.findall(r'\byes\b', content, flags=re.IGNORECASE))
                        no_count = len(re.findall(r'\bno\b', content, flags=re.IGNORECASE))
                        if yes_count > 0 and no_count == 0:
                            decisions[i] = True
                            success = True
                        elif no_count > 0 and yes_count == 0:
                            decisions[i] = False
                            success = True
                        elif yes_count > 0 or no_count > 0:
                            # Both words present: majority vote, ties resolve to yes.
                            decisions[i] = yes_count >= no_count
                            success = True
                        else:
                            decisions[i] = None
                            last_error = f"unparseable response (no yes/no found): {content!r}"
                    responses[i] = content
                except Exception as exc:
                    last_error = str(exc)
                    decisions[i] = None
                    responses[i] = None

                if success:
                    break
                if attempt < max_retries:
                    # Exponential backoff: 2s, 4s, ...
                    delay = retry_delay * (2 ** (attempt - 1))
                    _time.sleep(delay)

            if not success:
                seq_info = ""
                if have_frame_info:
                    seq_info = f" seq={seq_names[i]} fa={frame_ids_a[i]} fb={frame_ids_b[i]}"
                print(
                    f"[TeacherLabel] FAILED after {max_retries} retries "
                    f"(sample {i}/{batch_size}{seq_info}): {last_error}"
                )

            # Only successful decisions are written back to the cache.
            if have_frame_info and decisions[i] is not None:
                self.teacher_label_cache.set(
                    seq_names[i], frame_ids_a[i], frame_ids_b[i], decisions[i]
                )

            # Brief pause between consecutive requests.
            if idx_in_uncached < len(uncached_indices) - 1:
                _time.sleep(0.1)

        return decisions, responses

    @staticmethod
    def _parse_update_decisions(decoded_outputs):
        """Map decoded generations to booleans: True only for an unambiguous "yes".

        The text inside the last ``<answer>...</answer>`` pair is inspected (or
        the whole text if no such pair exists). Anything other than a lone
        "yes" yields False.
        """
        decisions = []
        for text in decoded_outputs:
            text_l = text.lower()
            answer_start = text_l.rfind("<answer>")
            answer_end = text_l.find("</answer>", answer_start + len("<answer>")) if answer_start >= 0 else -1
            answer = text_l[answer_start + len("<answer>"):answer_end].strip() if answer_start >= 0 and answer_end >= 0 else text_l
            answer = answer.replace("<|im_end|>", " ").replace("<|endoftext|>", " ")
            tokens = answer.replace("<", " ").replace(">", " ").replace("/", " ").split()
            if "yes" in tokens and "no" not in tokens:
                decisions.append(True)
            elif "no" in tokens:
                decisions.append(False)
            else:
                decisions.append(False)
        return decisions

    def _apply_qwen_chat_template(self, message):
        """Render one chat message with the generation prompt, disabling thinking when supported."""
        try:
            return self.processor.apply_chat_template(
                message,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        except TypeError:
            # Processors whose template does not accept enable_thinking.
            return self.processor.apply_chat_template(
                message,
                tokenize=False,
                add_generation_prompt=True,
            )

    def _target_state_answer_sequences(self):
        """Return the token ids of the only two allowed answers: ``[yes_ids, no_ids]``."""
        outputs = [
            f"<answer>yes</answer><state_token>{self.token}</state_token>",
            f"<answer>no</answer><state_token>{self.token}</state_token>",
        ]
        return [self.tokenizer(text, add_special_tokens=False).input_ids for text in outputs]

    def _target_state_answer_text(self, decision):
        """Return the full answer string (yes/no tag plus state-token tag) for ``decision``."""
        answer = "yes" if decision else "no"
        return f"<answer>{answer}</answer><state_token>{self.token}</state_token>"

    @staticmethod
    def _find_subsequence(sequence, subsequence):
        """Return the start index of the LAST occurrence of ``subsequence``, or -1."""
        if len(subsequence) == 0 or len(sequence) < len(subsequence):
            return -1
        for start in range(len(sequence) - len(subsequence), -1, -1):
            if sequence[start:start + len(subsequence)] == subsequence:
                return start
        return -1

    def _build_forward_labels(self, input_ids, decisions, valid_decisions):
        """Build LM labels and token positions for the teacher-forced answer span.

        Returns:
            labels: ``input_ids``-shaped tensor with the answer span copied and
                everything else (including the yes/no decision tokens) set to -100.
            answer_token_positions: ``(batch, pos)`` of decision tokens for rows
                whose teacher decision is valid.
            target_token_positions: ``(batch, pos)`` of the target-state token.
        """
        labels = torch.full_like(input_ids, -100)
        answer_token_positions = []
        target_token_positions = []
        yes_ids = self.tokenizer("yes", add_special_tokens=False).input_ids
        no_ids = self.tokenizer("no", add_special_tokens=False).input_ids

        for batch_idx, decision in enumerate(decisions):
            answer_text = self._target_state_answer_text(decision)
            answer_ids = self.tokenizer(answer_text, add_special_tokens=False).input_ids
            row = input_ids[batch_idx].detach().cpu().tolist()
            start = self._find_subsequence(row, answer_ids)
            if start < 0:
                # Not found verbatim: assume the answer occupies the row's tail.
                start = max(0, len(row) - len(answer_ids))
            end = min(len(row), start + len(answer_ids))
            labels[batch_idx, start:end] = input_ids[batch_idx, start:end]

            decision_ids = yes_ids if decision else no_ids
            decision_rel = self._find_subsequence(answer_ids, decision_ids)
            if decision_rel >= 0:
                decision_positions = [
                    start + decision_rel + offset
                    for offset in range(len(decision_ids))
                    if start + decision_rel + offset < input_ids.shape[1]
                ]
                # The decision token is excluded from the format loss; it is
                # supervised separately by the teacher loss (valid rows only).
                for pos in decision_positions:
                    labels[batch_idx, pos] = -100
                if valid_decisions[batch_idx]:
                    answer_token_positions.extend((batch_idx, pos) for pos in decision_positions)

            target_rel = self._find_subsequence(answer_ids, [self.target_token_id])
            if target_rel >= 0 and start + target_rel < input_ids.shape[1]:
                target_token_positions.append((batch_idx, start + target_rel))

        return labels, answer_token_positions, target_token_positions

    def _answer_loss_from_forward_logits(self, logits, input_ids, answer_token_positions):
        """Cross-entropy of the decision tokens, predicted from the preceding position's logits."""
        valid_positions = [(b, pos) for b, pos in answer_token_positions if pos > 0]
        if not valid_positions:
            return logits.new_tensor(0.0)
        pred_logits = torch.stack([logits[b, pos - 1] for b, pos in valid_positions], dim=0).float()
        targets = torch.tensor(
            [int(input_ids[b, pos].item()) for b, pos in valid_positions],
            device=logits.device,
            dtype=torch.long,
        )
        return F.cross_entropy(pred_logits, targets)

    def _student_decisions_from_forward_logits(self, logits, input_ids, answer_token_positions, batch_size):
        """Read the student's yes/no preference from teacher-forced logits.

        Returns ``None`` when "yes"/"no" are not single tokens. Otherwise
        returns ``(decisions, valid)``, two bool tensors of shape
        ``(batch_size,)``. Rows without a usable decision position are
        invalid and default to False.
        """
        yes_ids = self.tokenizer("yes", add_special_tokens=False).input_ids
        no_ids = self.tokenizer("no", add_special_tokens=False).input_ids
        if len(yes_ids) != 1 or len(no_ids) != 1:
            return None
        yes_id, no_id = int(yes_ids[0]), int(no_ids[0])
        scores = logits.new_full((batch_size, 2), float("nan"), dtype=torch.float32)
        for b, pos in answer_token_positions:
            if pos <= 0 or b < 0 or b >= batch_size:
                continue
            target_id = int(input_ids[b, pos].item())
            if target_id not in (yes_id, no_id):
                continue
            # Detached: these scores only feed hard decisions, not a loss.
            pred = logits[b, pos - 1].detach().float()
            scores[b, 0] = pred[no_id]
            scores[b, 1] = pred[yes_id]
        valid = ~torch.isnan(scores).any(dim=1)
        decisions = scores[:, 1] >= scores[:, 0]
        decisions = decisions.to(dtype=torch.bool)
        decisions[~valid] = False
        return decisions, valid

    def _target_hidden_from_forward(self, hidden_states, input_ids, target_token_positions):
        """Gather, per row, the last-layer hidden state at the target-state token.

        Falls back to the last occurrence of the token id in the row, then to
        the last non-pad token. ``seq_delta`` compensates for a length
        mismatch between hidden states and ``input_ids``.
        """
        h_targets = []
        seq_delta = hidden_states.shape[1] - input_ids.shape[1]
        for batch_idx in range(input_ids.shape[0]):
            positions = [pos for b, pos in target_token_positions if b == batch_idx]
            if positions:
                pos = positions[-1]
            else:
                target_positions = input_ids[batch_idx].eq(self.target_token_id).nonzero(as_tuple=False).flatten()
                if target_positions.numel() > 0:
                    pos = int(target_positions[-1].item())
                else:
                    non_pad = input_ids[batch_idx].ne(self.tokenizer.pad_token_id).nonzero(as_tuple=False).flatten()
                    pos = int(non_pad[-1].item()) if non_pad.numel() > 0 else input_ids.shape[1] - 1
            hidden_pos = min(max(pos + seq_delta, 0), hidden_states.shape[1] - 1)
            h_targets.append(hidden_states[batch_idx, hidden_pos])
        return torch.stack(h_targets, dim=0).float()

    def _qwen_forward_with_teacher_targets(self, texts, images, teacher_decisions, device):
        """Teacher-forced training pass over prompts with the teacher's answer appended.

        Rows whose teacher decision is ``None`` receive a placeholder "no"
        answer. That answer is used only as a format target: those rows are
        excluded from the decision loss and get teacher label -1.

        Returns:
            ``(h_target, update_decisions, qwen_format_loss, qwen_teacher_loss,
            teacher_labels)``.

        Raises:
            RuntimeError: if no teacher labels are available at all, or their
                count does not match the batch.
        """
        if teacher_decisions is None:
            raise RuntimeError(
                "Forward target-state training requires teacher yes/no labels. "
                "Set MODEL.TARGET_STATE.TEACHER_ENABLE=True or export QWEN_TEACHER_ENABLE=true."
            )
        decisions = [bool(decision) if decision is not None else False for decision in teacher_decisions]
        valid_decisions = [decision is not None for decision in teacher_decisions]
        if len(decisions) != len(texts):
            raise RuntimeError(
                f"Teacher label count ({len(decisions)}) does not match batch size ({len(texts)})."
            )
        if not any(valid_decisions):
            # Every teacher call failed. The placeholder "no" answers stay as
            # format targets, but all rows remain marked invalid, exactly as
            # in a partial failure. That way no decision supervision or
            # teacher label is fabricated for this batch.
            import warnings
            warnings.warn(
                "Teacher update judge failed for every sample in this batch; "
                "using placeholder no answers without decision supervision.",
                RuntimeWarning,
            )

        target_texts = [self._target_state_answer_text(decision) for decision in decisions]
        full_texts = [text + target_text for text, target_text in zip(texts, target_texts)]
        tokenized = self.processor(text=full_texts, images=images, padding=True, return_tensors="pt").to(device)
        labels, answer_token_positions, target_token_positions = self._build_forward_labels(
            tokenized.input_ids, decisions, valid_decisions
        )
        outputs = self._qwen_forward_with_target_embedding(tokenized, labels=labels)
        qwen_format_loss = outputs.loss if outputs.loss is not None else outputs.logits.new_tensor(0.0)
        qwen_teacher_loss = self._answer_loss_from_forward_logits(
            outputs.logits, tokenized.input_ids, answer_token_positions
        )
        h_target = self._target_hidden_from_forward(outputs.hidden_states[-1], tokenized.input_ids, target_token_positions)
        teacher_decision_tensor = torch.tensor(decisions, device=device, dtype=torch.bool)
        student_decision_info = self._student_decisions_from_forward_logits(
            outputs.logits, tokenized.input_ids, answer_token_positions, len(decisions)
        )
        if student_decision_info is None:
            update_decisions = teacher_decision_tensor
        else:
            student_decisions, valid_student = student_decision_info
            # Student choice where readable, otherwise the (possibly placeholder) teacher value.
            update_decisions = torch.where(valid_student.to(device=device), student_decisions.to(device=device), teacher_decision_tensor)
        teacher_labels = torch.tensor(
            [1 if decision else 0 if valid else -1 for decision, valid in zip(decisions, valid_decisions)],
            device=device,
            dtype=torch.long,
        )
        return h_target, update_decisions, qwen_format_loss, qwen_teacher_loss, teacher_labels

    def _qwen_generate(self, **kwargs):
        """Call Qwen ``generate``; in training mode bypass its ``__wrapped__`` decorator.

        For PEFT models the base model's unwrapped ``generate`` is called inside
        PEFT's forward-hook context, with PEFT-only kwargs removed.
        """
        if self.training and hasattr(self.qwen, "get_base_model"):
            base_model = self.qwen.get_base_model()
            unwrapped = getattr(base_model.generate, "__wrapped__", None)
            if unwrapped is not None:
                with self.qwen._enable_peft_forward_hooks(**kwargs):
                    peft_args = getattr(self.qwen, "special_peft_forward_args", set())
                    clean_kwargs = {k: v for k, v in kwargs.items() if k not in peft_args}
                    return unwrapped(base_model, **clean_kwargs)

        generate_fn = self.qwen.generate
        if self.training:
            unwrapped = getattr(generate_fn, "__wrapped__", None)
            if unwrapped is not None:
                return unwrapped(self.qwen, **kwargs)
        return generate_fn(**kwargs)

    def _qwen_generation_kwargs(self, prompt_len=None):
        """Greedy-decoding kwargs; with ``prompt_len`` the output is constrained to the two answers."""
        eos_token_ids = []
        for token in ("<|im_end|>", "<|endoftext|>"):
            token_id = self.tokenizer.convert_tokens_to_ids(token)
            if isinstance(token_id, int) and token_id >= 0 and token_id != self.tokenizer.unk_token_id:
                eos_token_ids.append(token_id)
        if self.tokenizer.eos_token_id is not None:
            eos_token_ids.append(self.tokenizer.eos_token_id)
        eos_token_ids = list(dict.fromkeys(eos_token_ids))
        kwargs = {
            "max_new_tokens": 16,
            "do_sample": False,
            "num_beams": 1,
            "repetition_penalty": 1.0,
            "eos_token_id": eos_token_ids or self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if prompt_len is not None:
            answer_sequences = self._target_state_answer_sequences()
            stop_ids = eos_token_ids or [self.tokenizer.eos_token_id]

            def prefix_allowed_tokens_fn(batch_id, input_ids):
                # Allow only tokens that keep the suffix a prefix of a valid
                # answer; once an answer is complete, only stop tokens.
                suffix = input_ids[prompt_len:].tolist()
                allowed = []
                for sequence in answer_sequences:
                    if len(suffix) <= len(sequence) and suffix == sequence[:len(suffix)]:
                        if len(suffix) == len(sequence):
                            allowed.extend(stop_ids)
                        else:
                            allowed.append(sequence[len(suffix)])
                return list(dict.fromkeys(allowed)) or stop_ids

            kwargs["prefix_allowed_tokens_fn"] = prefix_allowed_tokens_fn
        return kwargs

    def _format_loss_from_generation_scores(self, scores, generated_suffix):
        """Cross-entropy of the generated tokens under the generation scores.

        Pad tokens and the yes/no decision step are ignored. Returns 0 when
        nothing is left to supervise: ``F.cross_entropy`` would otherwise
        return NaN when every target is ignored.
        """
        if scores is None or len(scores) == 0:
            return generated_suffix.new_tensor(0.0, dtype=torch.float32)
        num_steps = min(len(scores), generated_suffix.shape[1])
        logits = torch.stack(scores[:num_steps], dim=1).float()
        targets = generated_suffix[:, :num_steps].clone()
        if self.tokenizer.pad_token_id is not None:
            targets[targets == self.tokenizer.pad_token_id] = -100

        yes_seq, no_seq = self._target_state_answer_sequences()
        decision_step = next((i for i, (yes_id, no_id) in enumerate(zip(yes_seq, no_seq)) if yes_id != no_id), None)
        if decision_step is not None and decision_step < targets.shape[1]:
            targets[:, decision_step] = -100

        if not bool(targets.ne(-100).any()):
            return logits.new_tensor(0.0)

        return F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            targets.reshape(-1),
            ignore_index=-100,
        )

    def _teacher_decision_loss(self, scores, teacher_decisions):
        """Cross-entropy of the generation's decision step against teacher yes/no labels.

        Returns ``(loss, labels)``, where ``labels`` holds 1/0 per valid row and
        -1 otherwise. ``labels`` is None when no loss can be formed.
        """
        valid_items = [(idx, decision) for idx, decision in enumerate(teacher_decisions or []) if decision is not None]
        if not valid_items or scores is None or len(scores) == 0:
            device = scores[0].device if scores else self.qwen.get_input_embeddings().weight.device
            return torch.tensor(0.0, device=device), None

        yes_seq, no_seq = self._target_state_answer_sequences()
        decision_step = next((i for i, (yes_id, no_id) in enumerate(zip(yes_seq, no_seq)) if yes_id != no_id), None)
        if decision_step is None or decision_step >= len(scores):
            return scores[0].new_tensor(0.0), None

        batch_indices = torch.tensor([idx for idx, _ in valid_items], device=scores[decision_step].device, dtype=torch.long)
        target_ids = torch.tensor(
            [yes_seq[decision_step] if decision else no_seq[decision_step] for _, decision in valid_items],
            device=scores[decision_step].device,
            dtype=torch.long,
        )
        logits = scores[decision_step].float().index_select(0, batch_indices)
        loss = F.cross_entropy(logits, target_ids)
        labels = torch.full((len(teacher_decisions),), -1, device=scores[decision_step].device, dtype=torch.long)
        labels[batch_indices] = torch.tensor([1 if decision else 0 for _, decision in valid_items], device=labels.device)
        return loss, labels

    def _target_hidden_from_generation(self, generation_hidden_states, generated_suffix):
        """Gather the last-layer hidden state at each row's target-state token after generation.

        Each row uses the first occurrence of the target-state token in its own
        generated suffix. Only rows without that token fall back to their last
        non-pad token.

        Raises:
            RuntimeError: if generation returned no hidden states.
        """
        target_mask = generated_suffix.eq(self.target_token_id)
        has_target = target_mask.any(dim=1)
        if bool(has_target.all()):
            target_pos = target_mask.float().argmax(dim=1)
        else:
            non_pad = generated_suffix.ne(self.tokenizer.pad_token_id)
            target_pos = non_pad.sum(dim=1).clamp_min(1) - 1
            if bool(has_target.any()):
                # Rows that do contain the token keep their own position.
                first_target = target_mask.float().argmax(dim=1)
                target_pos = torch.where(has_target, first_target, target_pos)

        hidden_steps = generation_hidden_states or []
        if len(hidden_steps) == 0:
            raise RuntimeError("Qwen generation did not return hidden states.")
        h_targets = []
        for batch_idx, pos in enumerate(target_pos.detach().cpu().tolist()):
            # Step 0 holds the prompt pass. The token generated at suffix
            # position ``pos`` is fed back as input at step ``pos + 1``, so
            # that step's last position is the token's own hidden state.
            step = min(pos + 1, len(hidden_steps) - 1)
            last_hidden = hidden_steps[step][-1]
            h_targets.append(last_hidden[batch_idx, -1])
        return torch.stack(h_targets, dim=0).float()

    def forward(self, captions, template_images, search_images, template_boxes, search_boxes, device,
                object_names=None, return_update_decision=False,
                seq_names=None, template_frame_ids=None):
        """Judge template updates and compute the projected target-state embedding.

        In training mode a teacher-forced pass is run on the teacher's answers;
        otherwise a constrained generation is run. ``template_frame_ids`` (a
        tensor; its last two columns are used) together with ``seq_names``
        key the teacher-label cache.

        Returns:
            ``z_target`` by default; with ``return_update_decision=True`` the
            tuple ``(z_target, update_decisions, qwen_format_loss,
            qwen_teacher_loss, teacher_labels, response_outputs)``.
        """
        if object_names is None:
            object_names = [None] * len(captions)
        prompts = [self._build_prompt(caption, object_name) for caption, object_name in zip(captions, object_names)]
        teacher_prompts = [self._build_teacher_prompt(caption, object_name) for caption, object_name in zip(captions, object_names)]
        template_pils = self._tensor_batch_to_pil(template_images, template_boxes)
        search_pils = self._tensor_batch_to_pil(search_images, search_boxes)

        cache_seq_names = None
        cache_fa = None
        cache_fb = None
        if seq_names is not None and template_frame_ids is not None:
            cache_seq_names = seq_names
            cache_fa = template_frame_ids[:, -2].detach().cpu().tolist()
            cache_fb = template_frame_ids[:, -1].detach().cpu().tolist()

        teacher_decisions, teacher_outputs = self._query_teacher_decisions(
            teacher_prompts, template_pils, search_pils,
            seq_names=cache_seq_names, frame_ids_a=cache_fa, frame_ids_b=cache_fb,
        )
        messages = []
        for prompt, template_pil, search_pil in zip(prompts, template_pils, search_pils):
            messages.append([
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": template_pil},
                        {"type": "image", "image": search_pil},
                        {"type": "text", "text": prompt},
                    ],
                }
            ])
        texts = [self._apply_qwen_chat_template(message) for message in messages]
        images = [[template_pil, search_pil] for template_pil, search_pil in zip(template_pils, search_pils)]
        if self.training:
            h_target, update_decisions, qwen_format_loss, qwen_teacher_loss, teacher_labels = self._qwen_forward_with_teacher_targets(
                texts, images, teacher_decisions, device
            )
            response_outputs = teacher_outputs
        else:
            tokenized = self.processor(text=texts, images=images, padding=True, return_tensors="pt").to(device)
            generation = self._qwen_generate(
                **tokenized,
                **self._qwen_generation_kwargs(prompt_len=tokenized.input_ids.shape[1]),
                return_dict_in_generate=True,
                output_scores=True,
                output_hidden_states=True,
            )
            generated_ids = generation.sequences
            generated_suffix = generated_ids[:, tokenized.input_ids.shape[1]:]
            decoded_outputs = self.tokenizer.batch_decode(generated_suffix, skip_special_tokens=False)
            update_decisions = torch.tensor(
                self._parse_update_decisions(decoded_outputs), device=device, dtype=torch.bool
            )
            qwen_format_loss = self._format_loss_from_generation_scores(generation.scores, generated_suffix)
            qwen_teacher_loss, teacher_labels = self._teacher_decision_loss(generation.scores, teacher_decisions)
            h_target = self._target_hidden_from_generation(generation.hidden_states, generated_suffix)
            response_outputs = decoded_outputs
        z_target = self.projector(h_target)
        if return_update_decision:
            return z_target, update_decisions, qwen_format_loss, qwen_teacher_loss, teacher_labels, response_outputs
        return z_target

    def modulate_feature(self, opt_feat, z_target):
        """FiLM with per-channel learnable gate.

        Shapes:
            opt_feat (B, C, H, W) — tracker features
            z_target (B, C) — projected target-state embedding

        ``film_gate`` is a per-channel parameter initialised to sigmoid(-4) ≈ 0.018.
        This means modulation starts near identity and each channel independently
        learns how much to trust the target-state signal.
        """
        z = self.film_ln(z_target)
        gamma, beta = self.film(z).chunk(2, dim=-1)
        gate = torch.sigmoid(self.film_gate)
        gamma = gamma[:, :, None, None] * gate[None, :, None, None]
        beta = beta[:, :, None, None] * gate[None, :, None, None]
        return opt_feat * (1.0 + gamma) + beta
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1,
         freeze_bn=False):
    """Build a Conv2d -> BatchNorm -> ReLU(inplace) block.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        kernel_size, stride, padding, dilation: forwarded to ``nn.Conv2d``
            (the conv always has a bias).
        freeze_bn: use ``FrozenBatchNorm2d`` instead of ``nn.BatchNorm2d``.

    Returns:
        ``nn.Sequential`` of the three layers.
    """
    norm_layer = FrozenBatchNorm2d if freeze_bn else nn.BatchNorm2d
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        norm_layer(out_planes),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class ConfidencePred(nn.Module):
    """Small conv + MLP head mapping a 5-channel map to a confidence in (0, 1).

    Four ``conv`` stages shrink channels 5 -> 16 -> 8 -> 4 -> 2. A 1x1 conv then
    reduces to a single channel, and two linear layers (256 -> 512 -> 1) with a
    final sigmoid produce the score. ``fc1`` requires the single-channel map to
    flatten to 256 values (e.g. a 16x16 map).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): feat_sz / stride / img_sz are not read by forward(), and
        # feat_sz=24 does not match the 256-input fc1 (16x16); kept as-is.
        self.feat_sz = 24
        self.stride = 1
        self.img_sz = self.feat_sz * self.stride
        freeze_bn = False

        # Registration order matters: it fixes the state_dict key layout and the
        # order of random draws during parameter initialisation.
        width = 16
        channel_plan = (5, width, width // 2, width // 4, width // 8)
        for stage_no, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:]), start=1):
            setattr(self, f"conv{stage_no}_ctr", conv(c_in, c_out, freeze_bn=freeze_bn))
        self.conv5_ctr = nn.Conv2d(width // 8, 1, kernel_size=1)

        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # Xavier-uniform for every weight with more than one dim; 1-D params
        # (biases, BatchNorm affine) keep their default init.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, x, xz_feature=None, gt_score_map=None):
        """Return a (B, 1) confidence for ``x``; ``xz_feature`` and ``gt_score_map`` are unused."""
        stages = (self.conv1_ctr, self.conv2_ctr, self.conv3_ctr, self.conv4_ctr, self.conv5_ctr)
        feat = x
        for stage in stages:
            feat = stage(feat)
        hidden = self.relu(self.fc1(feat.flatten(1)))
        return self.sigmoid(self.fc2(hidden))
|
|
class SubjectIndexPred(nn.Module):
    """Per-token MLP (dim -> 256 -> 128 -> 1) giving a sigmoid score in (0, 1).

    It is applied along the last dimension, so an input of shape (..., dim)
    yields an output of shape (..., 1).
    """

    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

        # Xavier-uniform for weight matrices only; biases keep nn.Linear's default init.
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def forward(self, x):
        """Score every feature vector in ``x`` (last dim = ``dim``)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.sigmoid(self.fc3(hidden))
|
|
|
|
class ATCTrack(nn.Module):
    """Vision-language tracker with a temporal memory of visual target tokens.

    For each search frame the model:
    1. Runs the backbone over [cls | templates | search] tokens.
    2. Refines the caption features with a memory of past visual prompt tokens
       (``language_adjust``).
    3. Fuses language into the search tokens (``vl_fusion``).
    4. Updates the temporal memory (``visual_temporal_fusion``).
    5. Feeds the resulting feature map to the box head.

    When ``cfg.MODEL.TARGET_STATE.ENABLE`` is set, a ``QwenTargetStateEncoder``
    also chooses which dynamic template to use and FiLM-modulates the head input.
    """

    def __init__(self, transformer, box_head, tokenizer, text_encoder, aux_loss=False, head_type="CORNER", dim=512, cfg=None):
        """Build the ATCTrack sub-modules.

        Args:
            transformer: backbone. Must provide ``forward_features_pe``,
                ``forward_features_stage3``, ``pos_embed_z`` and ``pos_embed_x``.
            box_head: box prediction head. ``feat_sz`` is read for CORNER/CENTER heads.
            tokenizer: caption tokenizer. ``forward_text`` reads ``.encodings``,
                so a fast tokenizer is required.
            text_encoder: text model returning ``last_hidden_state``.
                ``text_adj`` assumes 768-d hidden states.
            aux_loss: if True, the box head is cloned 6 times.
            head_type: ``"CORNER"`` or ``"CENTER"``.
            dim: token embedding width used by the fusion modules.
            cfg: experiment config. It is passed to the transformer builders and
                read for ``MODEL.TARGET_STATE.ENABLE``.
        """
        super().__init__()
        self.backbone = transformer
        self.box_head = box_head

        self.aux_loss = aux_loss
        self.head_type = head_type
        if head_type == "CORNER" or head_type == "CENTER":
            self.feat_sz_s = int(box_head.feat_sz)
            self.feat_len_s = int(box_head.feat_sz ** 2)

        if self.aux_loss:
            self.box_head = _get_clones(self.box_head, 6)

        self.dim = dim

        # Keep the sub-module creation order stable: it fixes the state_dict layout
        # and the sequence of random draws used for initialisation.
        self.query_len = 1
        self.cls_prompts_pos = nn.Embedding(num_embeddings=self.query_len, embedding_dim=self.dim)

        self.confidence_pred = ConfidencePred()

        # Temporal memory: temporal_len visual prompt tokens, one per past step.
        self.visual_temporal_fusion = build_transformer_dec_with_mask(cfg, self.dim)
        self.temporal_len = 4
        self.dy_template_pos_embed = nn.Embedding(num_embeddings=self.temporal_len,
                                                  embedding_dim=self.dim)

        # Language branch.
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.text_adj = nn.Sequential(
            nn.Linear(768, self.dim, bias=True),
            nn.LayerNorm(self.dim, eps=1e-12),
            nn.Dropout(0.1),
        )

        self.language_adjust = build_transformer_dec(cfg, self.dim)
        self.vl_fusion = VisionLanguageFusionModule(dim=self.dim, num_heads=8, attn_drop=0.1, proj_drop=0.1,
                                                    num_vlfusion_layers=2,
                                                    vl_input_type='separate')

        self.text_pos = PositionEmbeddingSine1D(self.dim, normalize=True)

        # The attribute name (including the "idnex" typo) is part of checkpoint keys; do not rename.
        self.text_sub_idnex_classifier = SubjectIndexPred(self.dim)

        self.use_target_state = getattr(cfg.MODEL.TARGET_STATE, "ENABLE", False) if hasattr(cfg.MODEL, "TARGET_STATE") else False
        if self.use_target_state:
            self.target_state_encoder = QwenTargetStateEncoder(cfg, self.dim)
        else:
            self.target_state_encoder = None

    def forward_backbone(self, template, search, cls_token, soft_token_template_mask, x_pos):
        """Run the backbone on two channel-stacked templates plus the search images.

        ``template`` is split into its first 3 channels and the remaining
        channels (static and dynamic template). ``soft_token_template_mask`` is
        split at token 64 to match.

        NOTE(review): the split at 64 assumes each template yields 64 tokens.
        Confirm this against the backbone's template size.

        Returns:
            ``(x, aux_dict)`` as returned by ``backbone.forward_features_stage3``.
        """
        template = [template[:, :3], template[:, 3:]]
        soft_token_template_mask = [soft_token_template_mask[:, :64], soft_token_template_mask[:, 64:]]

        x, _token_type_infor = self.backbone.forward_features_pe(z=template, x=search, soft_token_template_mask=soft_token_template_mask)
        x, aux_dict = self.backbone.forward_features_stage3(x, cls_token, x_pos)
        return x, aux_dict

    def _encode_target_state(self, captions, template, template_bbox, new_template_bbox, device,
                             object_names, seq_names, template_frame_ids):
        """Run the target-state encoder on the last two templates (shared by training and inference).

        Returns the encoder's 6-tuple ``(z_target, update_decision, format_loss,
        teacher_loss, teacher_labels, outputs)``.
        """
        return self.target_state_encoder(
            captions, template[-2], template[-1], template_bbox,
            new_template_bbox, device, object_names=object_names,
            return_update_decision=True,
            seq_names=seq_names,
            template_frame_ids=template_frame_ids,
        )

    def forward(self, template: torch.Tensor,
                search: torch.Tensor,
                soft_token_template_mask=None,
                exp_str=None,
                exp_subject_mask=None,
                target_state_exp_str=None,
                target_state_template_bbox=None,
                target_state_new_template_bbox=None,
                target_state_object_name=None,
                target_state_z=None,
                target_state_seq_names=None,
                target_state_template_frame_ids=None,
                temporal_infor=None,
                first_frame_flag=False,
                training=True):
        """Track one batch of search frames.

        Training (``training=True``):
            ``template`` is a list of template tensors with the static template
            first. When the target-state encoder runs, the last two entries are
            the candidate dynamic templates. ``search`` is a list of
            ``num_search`` tensors, concatenated along the batch dim.
            ``exp_str`` holds raw captions, and ``soft_token_template_mask`` is a
            list aligned with ``template``.

        Inference (``training=False``):
            The batch size is taken as 1 and ``search`` is a single tensor.
            ``exp_str`` / ``exp_subject_mask`` are expected to be the precomputed
            ``NestedTensor`` pair returned by ``forward_text``.
            ``temporal_infor`` is the token memory returned by the previous call.
            Unless ``first_frame_flag`` resets it, that list is updated in place.

        Returns:
            Dict from ``forward_head``, extended with the backbone aux outputs,
            ``backbone_feat``, subject-mask prediction/target, the target-state
            outputs and (inference only) ``temporal_infor``.
        """
        # None instead of a shared mutable [] default.
        if temporal_infor is None:
            temporal_infor = []

        b0, num_search = template[0].shape[0], len(search)
        z_target = None
        target_state_update_decision = None
        qwen_format_loss = None
        qwen_teacher_loss = None
        qwen_teacher_labels = None
        qwen_teacher_outputs = None
        target_state_captions = target_state_exp_str if target_state_exp_str is not None else exp_str

        if training:
            search = torch.cat(search, dim=0)
            if self.use_target_state and target_state_captions and len(template) >= 3:
                (z_target, target_state_update_decision, qwen_format_loss, qwen_teacher_loss,
                 qwen_teacher_labels, qwen_teacher_outputs) = self._encode_target_state(
                    target_state_captions, template, target_state_template_bbox,
                    target_state_new_template_bbox, search.device, target_state_object_name,
                    target_state_seq_names, target_state_template_frame_ids)
                # Per sample: newest template (-1) where the encoder votes to update, else the previous one (-2).
                dynamic_template = torch.where(target_state_update_decision.view(b0, 1, 1, 1), template[-1], template[-2])
                dynamic_mask = torch.where(
                    target_state_update_decision.view(b0, 1, 1),
                    soft_token_template_mask[-1],
                    soft_token_template_mask[-2],
                )
            else:
                dynamic_template = template[1]
                dynamic_mask = soft_token_template_mask[1]
            template = torch.cat([template[0], dynamic_template], dim=1)
            soft_token_template_mask = torch.cat([soft_token_template_mask[0], dynamic_mask], dim=1)
            # Same template pair for every search frame.
            template_temporal = torch.cat([template] * num_search, dim=0)
            soft_token_template_mask_temporal = torch.cat([soft_token_template_mask] * num_search, dim=0)
        else:
            b0 = 1
            if target_state_z is not None:
                z_target = target_state_z.to(device=search.device)
                template_temporal = torch.cat(template[:2], dim=1)
                soft_token_template_mask_temporal = torch.cat(soft_token_template_mask[:2], dim=1)
            elif self.use_target_state and target_state_captions and len(template) >= 3:
                (z_target, target_state_update_decision, qwen_format_loss, qwen_teacher_loss,
                 qwen_teacher_labels, qwen_teacher_outputs) = self._encode_target_state(
                    target_state_captions, template, target_state_template_bbox,
                    target_state_new_template_bbox, search.device, target_state_object_name,
                    target_state_seq_names, target_state_template_frame_ids)
                use_newest = bool(target_state_update_decision[0].item())
                dynamic_template = template[-1] if use_newest else template[-2]
                dynamic_mask = soft_token_template_mask[-1] if use_newest else soft_token_template_mask[-2]
                template_temporal = torch.cat([template[0], dynamic_template], dim=1)
                soft_token_template_mask_temporal = torch.cat([soft_token_template_mask[0], dynamic_mask], dim=1)
            else:
                template_temporal = torch.cat(template[:2], dim=1)
                soft_token_template_mask_temporal = torch.cat(soft_token_template_mask[:2], dim=1)

        cls_prompts_pos = self.cls_prompts_pos.weight.unsqueeze(0)
        x_pos_0 = torch.cat([cls_prompts_pos, self.backbone.pos_embed_z, self.backbone.pos_embed_x], dim=1)
        x_pos = x_pos_0.repeat(b0 * num_search, 1, 1)
        x, aux_dict = self.forward_backbone(template_temporal, search, None, soft_token_template_mask_temporal,
                                            x_pos)

        # Raw captions are encoded only in training. Otherwise (inference) exp_str /
        # exp_subject_mask are used as precomputed features. Previously the
        # passthrough was nested under `if training:`, which left the text
        # features undefined (NameError) when training=False.
        if training and exp_str:
            text_features, text_subject_features, subject_infor_mask_pred, subject_infor_mask_gt = self.forward_text(
                exp_str, num_search, exp_subject_mask, device=search.device)
        else:
            text_features = exp_str
            text_subject_features = exp_subject_mask
            subject_infor_mask_pred = None
            subject_infor_mask_gt = None
        if training and z_target is not None and z_target.shape[0] == b0 and num_search > 1:
            # One embedding per (search frame, sample), matching x's batch layout.
            z_target = torch.cat([z_target] * num_search, dim=0)

        text_pos = self.text_pos(text_features)
        text_pos_0 = text_pos[:b0]
        x_s_pos_item = x_pos_0.repeat(b0, 1, 1)[:, -self.feat_len_s:]
        pre_temporal_pos = self.dy_template_pos_embed.weight.unsqueeze(1)
        pre_temporal_pos = pre_temporal_pos.repeat(b0, 1, self.query_len)
        pre_temporal_pos = pre_temporal_pos.view(b0, self.temporal_len * self.query_len, self.dim).contiguous()

        xt_data = []
        for temporal_index in range(num_search):
            batch_slice = slice(temporal_index * b0, (temporal_index + 1) * b0)
            x_item = x[batch_slice]
            visual_prompts_token = x_item[:, :self.query_len, :]
            # The memory is (re)initialised at the first training frame, or at
            # inference when the caller flags the first frame of a sequence.
            reset_memory = (temporal_index == 0) if training else first_frame_flag

            # Spatial prior over search tokens: search/template token affinity,
            # min-max normalised and inverted (lowest affinity -> 1).
            # NOTE(review): min/max are taken over the whole batch, not per sample.
            x_f = x_item[:, -self.feat_len_s:]
            x_f1 = torch.matmul(x_f, x_f.permute(0, 2, 1).contiguous())
            x_f = torch.matmul(x_f1, x_f)
            z_f = x_item[:, :-self.feat_len_s]
            x_z = torch.matmul(x_f, z_f.permute(0, 2, 1).contiguous())
            att_map = x_z.mean(-1)
            tensor_min = torch.min(att_map)
            tensor_max = torch.max(att_map)
            # Clamp avoids 0/0 = NaN when the map is constant.
            value_range = (tensor_max - tensor_min).clamp_min(torch.finfo(att_map.dtype).eps)
            normalized_tensor = (tensor_max - att_map) / value_range
            attn_xz = normalized_tensor.view(-1, self.feat_len_s, 1).contiguous()

            if reset_memory:
                temporal_infor = [visual_prompts_token] * self.temporal_len
            temporal_infor_data = torch.cat(temporal_infor, dim=1)

            # Language refinement with subject tokens + temporal memory.
            l_item_initial = text_features.tensors[batch_slice]
            l_item_subject = text_subject_features.tensors[batch_slice]
            l_mask_item_0 = text_features.mask[batch_slice]
            # NOTE(review): temporal slots are appended with mask=True, the same value
            # forward_text uses for padding; confirm the mask convention of language_adjust.
            temporal_mask = torch.ones((l_mask_item_0.shape[0], self.temporal_len),
                                       dtype=torch.bool, device=l_mask_item_0.device)
            l_mask_item = torch.cat([l_mask_item_0, temporal_mask], dim=1)

            l_subject_temporal = torch.cat([l_item_subject, temporal_infor_data], dim=1)
            l_subject_temporal_pos = torch.cat([text_pos_0, pre_temporal_pos], dim=1)

            l_item_update, _ = self.language_adjust([l_item_initial, l_subject_temporal], None,
                                                    text_pos_0, l_subject_temporal_pos, l_mask_item)
            l_all = torch.cat([l_item_initial, l_item_update], dim=1)
            x_s_item = x_item[:, -self.feat_len_s:]
            x_s_item = self.vl_fusion(x_s_item,
                                      l_all,
                                      query_pos=x_pos_0[:, -self.feat_len_s:],
                                      memory_pos=torch.cat([text_pos_0, text_pos_0], dim=1),
                                      memory_key_padding_mask=torch.cat([l_mask_item_0, l_mask_item_0], dim=1),
                                      need_weights=False)

            temporal_infor_update = self.visual_temporal_fusion(temporal_infor_data, x_s_item, attn_xz, pre_temporal_pos, kv_pos=x_s_pos_item)
            temporal_item = temporal_infor_update[:, -1, :].unsqueeze(1)

            # Weight search tokens by their similarity to the newest temporal token.
            enc_opt = x_s_item
            dec_opt = temporal_item.transpose(1, 2)
            att = torch.matmul(enc_opt, dec_opt)
            opt = (enc_opt.unsqueeze(-1) * att.unsqueeze(-2)).permute((0, 3, 2, 1)).contiguous()
            C = opt.size(2)
            opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s)
            if z_target is not None:
                opt_feat = self.target_state_encoder.modulate_feature(opt_feat, z_target[batch_slice])
            xt_data.append(opt_feat)

            if reset_memory:
                temporal_infor = [temporal_item] * self.temporal_len
            else:
                # Slide the window in place. At inference, without a reset, this
                # mutates the list passed in by the caller.
                temporal_infor[:-1] = temporal_infor[1:]
                temporal_infor[-1] = temporal_item

        xt_data = torch.cat(xt_data, dim=0)
        out = self.forward_head(xt_data, None)

        out.update(aux_dict)
        out['backbone_feat'] = x
        out['subject_infor_mask_pred'] = subject_infor_mask_pred
        out['subject_infor_mask_gt'] = subject_infor_mask_gt
        out['target_state_update_decision'] = target_state_update_decision
        out['qwen_format_loss'] = qwen_format_loss
        out['qwen_teacher_loss'] = qwen_teacher_loss
        out['qwen_teacher_labels'] = qwen_teacher_labels
        out['qwen_teacher_outputs'] = qwen_teacher_outputs

        if not training:
            out["temporal_infor"] = temporal_infor

        return out

    def forward_head(self, opt_feat, gt_score_map=None):
        """Apply the box head to the fused search feature map.

        Args:
            opt_feat: feature map of shape (N, C, feat_sz_s, feat_sz_s).
            gt_score_map: forwarded to a CENTER head; ignored by a CORNER head.

        Returns:
            Dict with ``pred_boxes`` of shape (N, 1, 4) and ``score_map``. For
            CORNER, the boxes are converted from xyxy to (cx, cy, w, h). CENTER
            additionally returns ``size_map``, ``offset_map`` and
            ``confidence_pred``.

        Raises:
            NotImplementedError: for any other ``head_type``.
        """
        bs = opt_feat.shape[0]
        Nq = 1

        if self.head_type == "CORNER":
            pred_box, score_map = self.box_head(opt_feat, True)
            outputs_coord = box_xyxy_to_cxcywh(pred_box)
            return {'pred_boxes': outputs_coord.view(bs, Nq, 4).contiguous(),
                    'score_map': score_map,
                    }

        if self.head_type == "CENTER":
            score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map)
            # Confidence branch sees centre, size and offset maps stacked on channels.
            score_map = torch.cat([score_map_ctr, size_map, offset_map], dim=1)
            confidence_pred = self.confidence_pred(score_map)
            return {'pred_boxes': bbox.view(bs, Nq, 4).contiguous(),
                    'score_map': score_map_ctr,
                    'size_map': size_map,
                    'offset_map': offset_map,
                    "confidence_pred": confidence_pred}

        raise NotImplementedError

    def forward_text(self, captions, num_search, exp_subject_mask, device):
        """Encode captions and predict per-token subject weights.

        Args:
            captions: caption strings, one per sample.
            num_search: number of search frames. The text outputs are repeated
                this many times along the batch dim.
            exp_subject_mask: optional per-caption collection of word indices that
                belong to the subject. It is only used to build
                ``subject_infor_mask_gt``.
            device: device for the tokenized inputs.

        Returns:
            ``(text_features, subject_infor, subject_infor_mask_pred, subject_infor_mask_gt)``:
            - ``text_features`` / ``subject_infor``: ``NestedTensor``s of projected
              token features (the latter weighted by the predicted subject
              scores), repeated ``num_search`` times. Their mask is True on
              padding tokens.
            - ``subject_infor_mask_pred``: per-token scores for the un-repeated batch.
            - ``subject_infor_mask_gt``: (B, L) 0/1 float tensor, or None.
        """
        tokenized = self.tokenizer(captions, padding=True, return_tensors="pt").to(device)
        encoded_text = self.text_encoder(**tokenized)

        # True where attention_mask != 1, i.e. on padding positions.
        text_attention_mask = tokenized.attention_mask.ne(1).bool()
        text_features = self.text_adj(encoded_text.last_hidden_state)

        subject_infor_mask_gt = None
        if exp_subject_mask is not None:
            subject_infor_mask_gt = torch.zeros(text_attention_mask.shape[0], text_attention_mask.shape[1],
                                                device=text_features.device)
            # Mark every sub-token whose source word index is listed as a subject word.
            for item_index, encoding in enumerate(tokenized.encodings):
                subject_words = exp_subject_mask[item_index]
                token_indices = [token_index for token_index, word_index in enumerate(encoding.word_ids)
                                 if word_index in subject_words]
                subject_infor_mask_gt[item_index, token_indices] = 1

        subject_infor_mask_pred = self.text_sub_idnex_classifier(text_features)
        subject_infor = text_features * subject_infor_mask_pred.expand_as(text_features)

        # Repeat per search frame so the batch layout matches the backbone output.
        text_attention_mask = torch.cat([text_attention_mask] * num_search, dim=0)
        text_features = NestedTensor(torch.cat([text_features] * num_search, dim=0), text_attention_mask)
        subject_infor = NestedTensor(torch.cat([subject_infor] * num_search, dim=0), text_attention_mask)

        return text_features, subject_infor, subject_infor_mask_pred, subject_infor_mask_gt
|
|
|
|
class MLP(nn.Module):
    """Plain multi-layer perceptron (FFN) with ReLU between layers.

    ``num_layers`` linear layers map ``input_dim`` -> ``hidden_dim`` -> ... ->
    ``output_dim``, with no activation after the last one. ``num_layers <= 1``
    yields a single ``input_dim -> output_dim`` layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        n_hidden = num_layers - 1
        in_dims = [input_dim] + [hidden_dim] * n_hidden
        out_dims = [hidden_dim] * n_hidden + [output_dim]
        self.layers = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)])

    def forward(self, x):
        """Apply the layers in order, with ReLU after each layer whose index is below ``num_layers - 1``."""
        relu_limit = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < relu_limit:
                x = F.relu(x)
        return x
|
|
def build_atctrack(cfg, training=True):
    """Build an ``ATCTrack`` model from ``cfg``.

    Args:
        cfg: experiment config. Reads ``MODEL.PRETRAIN_FILE``,
            ``MODEL.PRETRAINED_PATH``, ``MODEL.BACKBONE.TYPE``,
            ``MODEL.HEAD.TYPE`` and ``TRAIN.DROP_PATH_RATE``.
        training: when True, loads two kinds of weights:
            - backbone weights from ``MODEL.PRETRAIN_FILE`` (resolved under
              ``resource/pretrained_models``), unless that name contains
              "ATCTrack";
            - a full ATCTrack checkpoint from ``MODEL.PRETRAINED_PATH``, if that
              path contains "ATCTrack".

    Returns:
        The constructed ``ATCTrack`` model.

    Raises:
        NotImplementedError: for an unsupported ``MODEL.BACKBONE.TYPE``.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    pretrained_path = os.path.join(current_dir, '../../../resource/pretrained_models')

    if cfg.MODEL.PRETRAIN_FILE and training and ("ATCTrack" not in cfg.MODEL.PRETRAIN_FILE):
        pretrained = os.path.join(pretrained_path, cfg.MODEL.PRETRAIN_FILE)
    else:
        pretrained = ''

    backbone_builders = {
        'hivit_base_adaptor': hivit_base,
        'itpn_base': fast_itpn_base_3324_patch16_224,
        'itpn_large': fast_itpn_large_2240_patch16_256,
    }
    backbone_type = cfg.MODEL.BACKBONE.TYPE
    if backbone_type not in backbone_builders:
        raise NotImplementedError(f"Unsupported backbone type: {backbone_type}")
    backbone = backbone_builders[backbone_type](pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE)
    hidden_dim = backbone.embed_dim
    patch_start_index = 1

    backbone.finetune_track(cfg=cfg, dim=hidden_dim, patch_start_index=patch_start_index)

    box_head = build_box_head(cfg, hidden_dim)

    # ROBERTA_MODEL_PATH overrides the bundled roberta-base location.
    roberta_path = _resolve_project_path(os.environ.get("ROBERTA_MODEL_PATH", os.path.join(pretrained_path, 'roberta-base')))
    tokenizer = RobertaTokenizerFast.from_pretrained(roberta_path)
    text_encoder = RobertaModel.from_pretrained(roberta_path)

    model = ATCTrack(
        backbone,
        box_head,
        tokenizer,
        text_encoder,
        aux_loss=False,
        head_type=cfg.MODEL.HEAD.TYPE,
        dim=hidden_dim,
        cfg=cfg
    )

    pretrained_checkpoint = _resolve_project_path(cfg.MODEL.PRETRAINED_PATH)
    # Guard against an unset path: `"ATCTrack" in None` would raise TypeError.
    if training and pretrained_checkpoint and "ATCTrack" in pretrained_checkpoint:
        # weights_only=False unpickles arbitrary objects; only point this at trusted checkpoints.
        checkpoint = torch.load(pretrained_checkpoint, map_location="cpu", weights_only=False)
        # Plain-dict copy of the state dict, as loaded before.
        model_weight = dict(checkpoint["net"])
        missing_keys, unexpected_keys = model.load_state_dict(model_weight, strict=False)
        # Report the file actually loaded (previously PRETRAIN_FILE was printed by mistake).
        print('Load pretrained model from: ' + pretrained_checkpoint)
        print(f'  missing keys: {len(missing_keys)}, unexpected keys: {len(unexpected_keys)}')

    return model
|
|
def load_pretrained(model, pretrained_path, strict=False):
    """Load a legacy encoder/decoder checkpoint into ``model`` after remapping keys.

    The checkpoint's ``'net'`` state dict is adjusted as follows:

    * ``encoder.body.pos_embed`` is split in half along dim 1. The first half
      becomes ``encoder.body.pos_embed_search`` and the second half
      ``encoder.body.pos_embed_template``; the original key is dropped.
    * ``encoder.body.patch_embed.proj.{weight,bias}`` are duplicated into
      ``encoder.body.patch_embed_interface.proj.{weight,bias}``.
    * ``decoder.embedding.prompt_embeddings.weight`` takes the model's own
      shape. Every row is filled with the last row of
      ``decoder.embedding.word_embeddings.weight``.

    Args:
        model: module to load into.
        pretrained_path: path to a checkpoint containing a ``'net'`` state dict.
        strict: forwarded to ``load_state_dict``.

    Raises:
        KeyError: if the checkpoint or the model lacks one of the keys above.
    """
    # Same weights_only setting as build_atctrack: torch>=2.6 defaults to
    # weights_only=True, which rejects training checkpoints holding non-tensor
    # objects. This unpickles arbitrary objects, so only load trusted files.
    model_ckpt = torch.load(pretrained_path, map_location="cpu", weights_only=False)
    state_dict = model_ckpt['net']

    pos_st = state_dict.pop('encoder.body.pos_embed')
    half = pos_st.size(1) // 2
    state_dict['encoder.body.pos_embed_search'] = pos_st[:, :half]
    state_dict['encoder.body.pos_embed_template'] = pos_st[:, half:]

    for suffix in ('weight', 'bias'):
        state_dict[f'encoder.body.patch_embed_interface.proj.{suffix}'] = state_dict[f'encoder.body.patch_embed.proj.{suffix}']

    prompt_key = 'decoder.embedding.prompt_embeddings.weight'
    # Clone so the model's live parameter is not written through the state_dict view.
    prompt_weight = model.state_dict()[prompt_key].clone()
    prompt_weight[:] = state_dict['decoder.embedding.word_embeddings.weight'][-1]
    state_dict[prompt_key] = prompt_weight

    model.load_state_dict(state_dict, strict=strict)
|
|