""" ATCTrack Model """ import os import re import base64 from io import BytesIO import torch import math from torch import nn import torch.nn.functional as F from lib.utils.misc import NestedTensor # from .language_model import build_bert from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy, box_xyxy_to_cxcywh, box_iou ### aqatrack from lib.models.aqatrack.hivit import hivit_small, hivit_base from lib.models.aqatrack.itpn import itpn_base_3324_patch16_224 from lib.models.aqatrack.fast_itpn import fast_itpn_base_3324_patch16_224,fast_itpn_large_2240_patch16_256 from lib.models.transformers.transformer import build_rgb_det_decoder from lib.models.layers.transformer_dec import build_transformer_dec,build_transformer_dec_with_mask from torch.nn.modules.transformer import _get_clones from lib.models.layers.head import build_box_head import torch.nn.functional as F from lib.models.layers.frozen_bn import FrozenBatchNorm2d from transformers import BertTokenizer, BertModel, RobertaModel, RobertaTokenizerFast, AutoTokenizer, AutoProcessor from PIL import Image, ImageDraw from lib.models.transformers import build_decoder, VisionLanguageFusionModule, PositionEmbeddingSine1D,build_text_prompt_decoder TARGET_STATE_TOKEN = "" def _project_root(): return os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../..")) def _resolve_project_path(path): if not path or os.path.isabs(path): return path candidate = os.path.abspath(os.path.join(_project_root(), path)) if os.path.exists(candidate) or path.startswith((".", "..", "checkpoint", "resource")): return candidate return path def _load_qwen_target_state_model(model_path): try: from transformers import AutoModelForImageTextToText model_cls = AutoModelForImageTextToText except ImportError: from transformers import AutoModelForCausalLM model_cls = AutoModelForCausalLM try: return model_cls.from_pretrained(model_path, trust_remote_code=True) except ValueError as exc: raise RuntimeError( "Cannot load Qwen target-state model. The current transformers package " "does not recognize this Qwen architecture. Upgrade transformers in the " "training environment before enabling MODEL.TARGET_STATE." ) from exc class QwenTargetStateEncoder(nn.Module): def __init__(self, cfg, tracker_dim): super().__init__() ts_cfg = cfg.MODEL.TARGET_STATE self.model_path = _resolve_project_path(os.environ.get("QWEN_MODEL_PATH", ts_cfg.MODEL_PATH)) self.token = getattr(ts_cfg, "TOKEN", TARGET_STATE_TOKEN) self.prompt_template = getattr(ts_cfg, "PROMPT_TEMPLATE", "default") self.train_token_embedding = getattr(ts_cfg, "TRAIN_TOKEN_EMBEDDING", False) self.freeze_qwen = getattr(ts_cfg, "FREEZE_QWEN", True) self.use_lora = getattr(ts_cfg, "USE_LORA", False) self.lora_r = getattr(ts_cfg, "LORA_R", 8) self.lora_alpha = getattr(ts_cfg, "LORA_ALPHA", 16) self.lora_dropout = getattr(ts_cfg, "LORA_DROPOUT", 0.05) self.lora_target_modules = getattr(ts_cfg, "LORA_TARGET_MODULES", [ "in_proj_qkv", "out_proj", "in_proj_z", "in_proj_b", "in_proj_a", "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ]) teacher_enable_env = os.environ.get("QWEN_TEACHER_ENABLE") if teacher_enable_env is None: self.teacher_enable = bool(getattr(ts_cfg, "TEACHER_ENABLE", False)) else: self.teacher_enable = teacher_enable_env.strip().lower() in ("1", "true", "yes", "on") self.teacher_model = os.environ.get("QWEN_TEACHER_MODEL", getattr(ts_cfg, "TEACHER_MODEL", "qwen3.5")) self.teacher_base_url = os.environ.get("QWEN_TEACHER_BASE_URL", getattr(ts_cfg, "TEACHER_BASE_URL", "http://127.0.0.1:8001/v1")) self.teacher_api_key = os.environ.get("QWEN_TEACHER_API_KEY", getattr(ts_cfg, "TEACHER_API_KEY", "sk-no-key-required")) self.teacher_client = None self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True) self.tokenizer = getattr(self.processor, "tokenizer", None) if self.tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.padding_side = "left" self.qwen = _load_qwen_target_state_model(self.model_path) self.target_state_special_tokens = ["", "", "", "", self.token] special_tokens = {"additional_special_tokens": self.target_state_special_tokens} num_added = self.tokenizer.add_special_tokens(special_tokens) if num_added > 0: self.qwen.resize_token_embeddings(len(self.tokenizer)) self.target_token_id = self.tokenizer.convert_tokens_to_ids(self.token) self._embedding_grad_hook = None qwen_hidden_dim = self.qwen.config.text_config.hidden_size if hasattr(self.qwen.config, "text_config") else self.qwen.config.hidden_size self.projector = nn.Sequential( nn.Linear(qwen_hidden_dim, tracker_dim), nn.LayerNorm(tracker_dim), nn.GELU(), nn.Linear(tracker_dim, tracker_dim), ) # P1: LayerNorm stabilises z_target distribution before FiLM. # P0: per-channel gate with sigmoid(-4) ≈ 0.018 initial value, # so each channel independently learns when to trust z_target. self.film_ln = nn.LayerNorm(tracker_dim) self.film = nn.Linear(tracker_dim, tracker_dim * 2) self.film_gate = nn.Parameter(torch.full((tracker_dim,), -4.0)) if self.freeze_qwen: for p in self.qwen.parameters(): p.requires_grad = False if self.use_lora: self._enable_qwen_lora() self.configure_token_embedding_training(self.train_token_embedding) # Two-stage teacher labeling: persistent cache to avoid repeated API calls. self.teacher_label_cache = None def set_teacher_label_cache(self, cache): """Attach a :class:`TeacherLabelCache` for two-stage training. When set, ``_query_teacher_decisions`` checks the cache before calling the online teacher API. Cache misses fall back to the online teacher and the result is written back to the cache. """ self.teacher_label_cache = cache def _enable_qwen_lora(self): try: from peft import LoraConfig, get_peft_model except ImportError as exc: raise RuntimeError("MODEL.TARGET_STATE.USE_LORA=True requires the peft package.") from exc target_modules = self.lora_target_modules if isinstance(target_modules, str): target_modules = [item.strip() for item in target_modules.split(",") if item.strip()] config = LoraConfig( r=self.lora_r, lora_alpha=self.lora_alpha, target_modules=list(target_modules), lora_dropout=self.lora_dropout, bias="none", task_type="CAUSAL_LM", ) self.qwen = get_peft_model(self.qwen, config) def configure_token_embedding_training(self, enabled): embedding = self.qwen.get_input_embeddings() embedding.weight.requires_grad = bool(enabled) if self._embedding_grad_hook is not None: self._embedding_grad_hook.remove() self._embedding_grad_hook = None if enabled: train_token_ids = torch.tensor([self.target_token_id], dtype=torch.long) def mask_embedding_grad(grad): token_ids = train_token_ids.to(grad.device) mask = torch.zeros((grad.shape[0],), device=grad.device, dtype=grad.dtype) mask.index_fill_(0, token_ids, 1) return grad * mask.view(-1, 1) self._embedding_grad_hook = embedding.weight.register_hook(mask_embedding_grad) def _qwen_forward_with_target_embedding(self, tokenized, labels=None): return self.qwen(**tokenized, labels=labels, output_hidden_states=True, use_cache=False) @staticmethod def _tensor_batch_to_pil(images, boxes=None): mean = images.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) std = images.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) images = (images.detach().float() * std + mean).clamp(0, 1) images = (images * 255).byte().permute(0, 2, 3, 1).cpu().numpy() pil_images = [Image.fromarray(image) for image in images] if boxes is None: return pil_images boxes = boxes.detach().float().cpu() for image, box in zip(pil_images, boxes): draw = ImageDraw.Draw(image) x, y, w, h = box.tolist() if max(abs(x), abs(y), abs(w), abs(h)) <= 2.0: img_w, img_h = image.size x, w = x * img_w, w * img_w y, h = y * img_h, h * img_h x1 = max(0.0, min(float(image.size[0] - 1), x)) y1 = max(0.0, min(float(image.size[1] - 1), y)) x2 = max(0.0, min(float(image.size[0] - 1), x + w)) y2 = max(0.0, min(float(image.size[1] - 1), y + h)) if x2 > x1 and y2 > y1: line_width = max(2, round(min(image.size) / 80)) draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), width=line_width) return pil_images def _build_prompt(self, caption, object_name=None): caption = caption if caption else "the target object" object_name = object_name if object_name else caption return ( f"Role: {object_name} tracking update judge and target-state token generator.\n\n" "Task: Compare the targets inside the provided bboxes in Frame 1 (Original) " "and Frame 2 (New). Decide whether Frame 2 should update the tracking template, " "and generate a target-state token for the tracking model.\n\n" "Reject update for full occlusion, out of view, too small target, severe blur/clipping, " "wrong bbox, distractor, uncertain identity, or no meaningful target appearance change.\n\n" "Accept update only if Frame 2 contains the same target as Frame 1, the bbox is reliable, " "the target is clear, and the appearance change is useful for future tracking.\n\n" "The target-state token should summarize the current target condition for the tracking model. " "It should encode whether the candidate is reliable, whether the target identity is consistent, " "and whether the current appearance is useful or risky for tracking.\n\n" "Frame 1 (Original) is the first image. Frame 2 (New candidate/search crop) is the second image.\n\n" "Output exactly one answer XML tag containing yes or no, immediately followed by one " "state_token XML tag containing the special target-state token. Do not output any extra text." ) def _build_teacher_prompt(self, caption, object_name=None): caption = caption if caption else "the target object" object_name = object_name if object_name else caption return ( f"Role: {object_name} tracking update judge.\n" "Task: Compare the targets inside the provided bboxes in Frame 1 (Original) and Frame 2 (New), " "and decide whether Frame 2 should update the tracking template.\n\n" "Reject update for full occlusion, out of view, too small target, severe blur/clipping, wrong bbox, " "distractor, uncertain identity, or no meaningful target appearance change.\n" "Accept only if Frame 2 is the same target, bbox is reliable, target is clear, and appearance change is useful.\n\n" "CRITICAL: Your entire response must be ONLY one of these two strings, " "with no other text, no explanation, no reasoning:\n" "yes\n" "no" ) @staticmethod def _pil_to_base64_jpeg(image): buffer = BytesIO() image.save(buffer, format="JPEG") return base64.b64encode(buffer.getvalue()).decode("utf-8") def _get_teacher_client(self): if self.teacher_client is None: try: from openai import OpenAI except ImportError as exc: raise RuntimeError("Teacher update judge requires the openai package.") from exc self.teacher_client = OpenAI(base_url=self.teacher_base_url, api_key=self.teacher_api_key, timeout=5.0) return self.teacher_client def _query_teacher_decisions(self, prompts, template_pils, search_pils, seq_names=None, frame_ids_a=None, frame_ids_b=None): """Query teacher API, with optional two-stage cache support. When ``seq_names`` / ``frame_ids_a`` / ``frame_ids_b`` are provided and a ``teacher_label_cache`` is attached, cached decisions are used directly. Cache misses fall back to the online teacher API with retry logic, and the result is saved back to the cache. """ if not self.teacher_enable: return None, None batch_size = len(prompts) decisions = [None] * batch_size responses = [None] * batch_size # ---- check cache first ---- have_frame_info = ( self.teacher_label_cache is not None and seq_names is not None and frame_ids_a is not None and frame_ids_b is not None ) uncached_indices = list(range(batch_size)) if have_frame_info: uncached_indices = [] for i in range(batch_size): cached = self.teacher_label_cache.get( seq_names[i], frame_ids_a[i], frame_ids_b[i] ) if cached is not None: decisions[i] = cached responses[i] = f"{'yes' if cached else 'no'}" else: uncached_indices.append(i) if not uncached_indices: return decisions, responses # ---- online teacher for uncached samples ---- import time as _time client = self._get_teacher_client() max_retries = 3 retry_delay = 2.0 # seconds, doubles each retry for idx_in_uncached, i in enumerate(uncached_indices): prompt, template_pil, search_pil = prompts[i], template_pils[i], search_pils[i] base64_image1 = self._pil_to_base64_jpeg(template_pil) base64_image2 = self._pil_to_base64_jpeg(search_pil) messages = [{ "role": "user", "content": [ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image1}"}}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image2}"}}, {"type": "text", "text": prompt}, ], }] success = False last_error = None for attempt in range(1, max_retries + 1): try: chat_response = client.chat.completions.create( model=self.teacher_model, messages=messages, max_tokens=8, temperature=0.0, top_p=1.0, presence_penalty=0.0, frequency_penalty=0.0, extra_body={ "top_k": 1, "seed": 0, "chat_template_kwargs": {"enable_thinking": False}, "guided_choice": ["yes", "no"], }, ) content = chat_response.choices[0].message.content # Try exact XML format first match = re.findall(r"\s*(yes|no)\s*", content, flags=re.IGNORECASE) if match: decisions[i] = match[-1].lower() == "yes" success = True else: # Fallback: extract yes/no from natural-language response. # Teacher model may ignore guided_choice and output a long # reasoning text that contains "yes" or "no". yes_count = len(re.findall(r'\byes\b', content, flags=re.IGNORECASE)) no_count = len(re.findall(r'\bno\b', content, flags=re.IGNORECASE)) if yes_count > 0 and no_count == 0: decisions[i] = True success = True elif no_count > 0 and yes_count == 0: decisions[i] = False success = True elif yes_count > 0 or no_count > 0: # Ambiguous — pick the majority decisions[i] = yes_count >= no_count success = True else: decisions[i] = None last_error = f"unparseable response (no yes/no found): {content!r}" responses[i] = content except Exception as exc: last_error = str(exc) decisions[i] = None responses[i] = None if success: break if attempt < max_retries: delay = retry_delay * (2 ** (attempt - 1)) _time.sleep(delay) if not success: seq_info = "" if have_frame_info: seq_info = f" seq={seq_names[i]} fa={frame_ids_a[i]} fb={frame_ids_b[i]}" print( f"[TeacherLabel] FAILED after {max_retries} retries " f"(sample {i}/{batch_size}{seq_info}): {last_error}" ) # write back to cache (only successes) if have_frame_info and decisions[i] is not None: self.teacher_label_cache.set( seq_names[i], frame_ids_a[i], frame_ids_b[i], decisions[i] ) # Small delay between samples to avoid overwhelming vLLM if idx_in_uncached < len(uncached_indices) - 1: _time.sleep(0.1) return decisions, responses @staticmethod def _parse_update_decisions(decoded_outputs): decisions = [] for text in decoded_outputs: text_l = text.lower() answer_start = text_l.rfind("") answer_end = text_l.find("", answer_start + len("")) if answer_start >= 0 else -1 answer = text_l[answer_start + len(""):answer_end].strip() if answer_start >= 0 and answer_end >= 0 else text_l answer = answer.replace("<|im_end|>", " ").replace("<|endoftext|>", " ") tokens = answer.replace("<", " ").replace(">", " ").replace("/", " ").split() if "yes" in tokens and "no" not in tokens: decisions.append(True) elif "no" in tokens: decisions.append(False) else: decisions.append(False) return decisions def _apply_qwen_chat_template(self, message): try: return self.processor.apply_chat_template( message, tokenize=False, add_generation_prompt=True, enable_thinking=False, ) except TypeError: return self.processor.apply_chat_template( message, tokenize=False, add_generation_prompt=True, ) def _target_state_answer_sequences(self): outputs = [ f"yes{self.token}", f"no{self.token}", ] return [self.tokenizer(text, add_special_tokens=False).input_ids for text in outputs] def _target_state_answer_text(self, decision): answer = "yes" if decision else "no" return f"{answer}{self.token}" @staticmethod def _find_subsequence(sequence, subsequence): if len(subsequence) == 0 or len(sequence) < len(subsequence): return -1 for start in range(len(sequence) - len(subsequence), -1, -1): if sequence[start:start + len(subsequence)] == subsequence: return start return -1 def _build_forward_labels(self, input_ids, decisions, valid_decisions): labels = torch.full_like(input_ids, -100) answer_token_positions = [] target_token_positions = [] yes_ids = self.tokenizer("yes", add_special_tokens=False).input_ids no_ids = self.tokenizer("no", add_special_tokens=False).input_ids for batch_idx, decision in enumerate(decisions): answer_text = self._target_state_answer_text(decision) answer_ids = self.tokenizer(answer_text, add_special_tokens=False).input_ids row = input_ids[batch_idx].detach().cpu().tolist() start = self._find_subsequence(row, answer_ids) if start < 0: start = max(0, len(row) - len(answer_ids)) end = min(len(row), start + len(answer_ids)) labels[batch_idx, start:end] = input_ids[batch_idx, start:end] decision_ids = yes_ids if decision else no_ids decision_rel = self._find_subsequence(answer_ids, decision_ids) if decision_rel >= 0: decision_positions = [ start + decision_rel + offset for offset in range(len(decision_ids)) if start + decision_rel + offset < input_ids.shape[1] ] # Keep format loss focused on the fixed XML/state-token scaffold. # The semantic yes/no decision is supervised only by teacher loss # so it is not diluted by the much easier constant tokens. for pos in decision_positions: labels[batch_idx, pos] = -100 if valid_decisions[batch_idx]: answer_token_positions.extend((batch_idx, pos) for pos in decision_positions) target_rel = self._find_subsequence(answer_ids, [self.target_token_id]) if target_rel >= 0 and start + target_rel < input_ids.shape[1]: target_token_positions.append((batch_idx, start + target_rel)) return labels, answer_token_positions, target_token_positions def _answer_loss_from_forward_logits(self, logits, input_ids, answer_token_positions): valid_positions = [(b, pos) for b, pos in answer_token_positions if pos > 0] if not valid_positions: return logits.new_tensor(0.0) pred_logits = torch.stack([logits[b, pos - 1] for b, pos in valid_positions], dim=0).float() targets = torch.tensor( [int(input_ids[b, pos].item()) for b, pos in valid_positions], device=logits.device, dtype=torch.long, ) return F.cross_entropy(pred_logits, targets) def _student_decisions_from_forward_logits(self, logits, input_ids, answer_token_positions, batch_size): yes_ids = self.tokenizer("yes", add_special_tokens=False).input_ids no_ids = self.tokenizer("no", add_special_tokens=False).input_ids if len(yes_ids) != 1 or len(no_ids) != 1: return None yes_id, no_id = int(yes_ids[0]), int(no_ids[0]) scores = logits.new_full((batch_size, 2), float("nan"), dtype=torch.float32) for b, pos in answer_token_positions: if pos <= 0 or b < 0 or b >= batch_size: continue target_id = int(input_ids[b, pos].item()) if target_id not in (yes_id, no_id): continue pred = logits[b, pos - 1].float() scores[b, 0] = pred[no_id] scores[b, 1] = pred[yes_id] valid = ~torch.isnan(scores).any(dim=1) decisions = scores[:, 1] >= scores[:, 0] decisions = decisions.to(dtype=torch.bool) decisions[~valid] = False return decisions, valid def _target_hidden_from_forward(self, hidden_states, input_ids, target_token_positions): h_targets = [] seq_delta = hidden_states.shape[1] - input_ids.shape[1] for batch_idx in range(input_ids.shape[0]): positions = [pos for b, pos in target_token_positions if b == batch_idx] if positions: pos = positions[-1] else: target_positions = input_ids[batch_idx].eq(self.target_token_id).nonzero(as_tuple=False).flatten() if target_positions.numel() > 0: pos = int(target_positions[-1].item()) else: non_pad = input_ids[batch_idx].ne(self.tokenizer.pad_token_id).nonzero(as_tuple=False).flatten() pos = int(non_pad[-1].item()) if non_pad.numel() > 0 else input_ids.shape[1] - 1 hidden_pos = min(max(pos + seq_delta, 0), hidden_states.shape[1] - 1) h_targets.append(hidden_states[batch_idx, hidden_pos]) return torch.stack(h_targets, dim=0).float() def _qwen_forward_with_teacher_targets(self, texts, images, teacher_decisions, device): if teacher_decisions is None: raise RuntimeError( "Forward target-state training requires teacher yes/no labels. " "Set MODEL.TARGET_STATE.TEACHER_ENABLE=True or export QWEN_TEACHER_ENABLE=true." ) decisions = [bool(decision) if decision is not None else False for decision in teacher_decisions] valid_decisions = [decision is not None for decision in teacher_decisions] if len(decisions) != len(texts): raise RuntimeError( f"Teacher label count ({len(decisions)}) does not match batch size ({len(texts)})." ) if not any(valid_decisions): # Teacher failed for every sample — fall back to all-"no" so # training can continue. A warning is printed so the user can # investigate the teacher service if this happens frequently. import warnings warnings.warn( "Teacher update judge failed for every sample in this batch; " "falling back to all-no decisions.", RuntimeWarning, ) decisions = [False] * len(texts) valid_decisions = [True] * len(texts) target_texts = [self._target_state_answer_text(decision) for decision in decisions] full_texts = [text + target_text for text, target_text in zip(texts, target_texts)] tokenized = self.processor(text=full_texts, images=images, padding=True, return_tensors="pt").to(device) labels, answer_token_positions, target_token_positions = self._build_forward_labels( tokenized.input_ids, decisions, valid_decisions ) outputs = self._qwen_forward_with_target_embedding(tokenized, labels=labels) qwen_format_loss = outputs.loss if outputs.loss is not None else outputs.logits.new_tensor(0.0) qwen_teacher_loss = self._answer_loss_from_forward_logits( outputs.logits, tokenized.input_ids, answer_token_positions ) h_target = self._target_hidden_from_forward(outputs.hidden_states[-1], tokenized.input_ids, target_token_positions) teacher_decision_tensor = torch.tensor(decisions, device=device, dtype=torch.bool) student_decision_info = self._student_decisions_from_forward_logits( outputs.logits, tokenized.input_ids, answer_token_positions, len(decisions) ) if student_decision_info is None: update_decisions = teacher_decision_tensor else: student_decisions, valid_student = student_decision_info update_decisions = torch.where(valid_student.to(device=device), student_decisions.to(device=device), teacher_decision_tensor) teacher_labels = torch.tensor( [1 if decision else 0 if valid else -1 for decision, valid in zip(decisions, valid_decisions)], device=device, dtype=torch.long, ) return h_target, update_decisions, qwen_format_loss, qwen_teacher_loss, teacher_labels def _qwen_generate(self, **kwargs): if self.training and hasattr(self.qwen, "get_base_model"): base_model = self.qwen.get_base_model() unwrapped = getattr(base_model.generate, "__wrapped__", None) if unwrapped is not None: with self.qwen._enable_peft_forward_hooks(**kwargs): peft_args = getattr(self.qwen, "special_peft_forward_args", set()) clean_kwargs = {k: v for k, v in kwargs.items() if k not in peft_args} return unwrapped(base_model, **clean_kwargs) generate_fn = self.qwen.generate if self.training: unwrapped = getattr(generate_fn, "__wrapped__", None) if unwrapped is not None: return unwrapped(self.qwen, **kwargs) return generate_fn(**kwargs) def _qwen_generation_kwargs(self, prompt_len=None): eos_token_ids = [] for token in ("<|im_end|>", "<|endoftext|>"): token_id = self.tokenizer.convert_tokens_to_ids(token) if isinstance(token_id, int) and token_id >= 0 and token_id != self.tokenizer.unk_token_id: eos_token_ids.append(token_id) if self.tokenizer.eos_token_id is not None: eos_token_ids.append(self.tokenizer.eos_token_id) eos_token_ids = list(dict.fromkeys(eos_token_ids)) kwargs = { "max_new_tokens": 16, "do_sample": False, "num_beams": 1, "repetition_penalty": 1.0, "eos_token_id": eos_token_ids or self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, } if prompt_len is not None: answer_sequences = self._target_state_answer_sequences() stop_ids = eos_token_ids or [self.tokenizer.eos_token_id] def prefix_allowed_tokens_fn(batch_id, input_ids): suffix = input_ids[prompt_len:].tolist() allowed = [] for sequence in answer_sequences: if len(suffix) <= len(sequence) and suffix == sequence[:len(suffix)]: if len(suffix) == len(sequence): allowed.extend(stop_ids) else: allowed.append(sequence[len(suffix)]) return list(dict.fromkeys(allowed)) or stop_ids kwargs["prefix_allowed_tokens_fn"] = prefix_allowed_tokens_fn return kwargs def _format_loss_from_generation_scores(self, scores, generated_suffix): if scores is None or len(scores) == 0: return generated_suffix.new_tensor(0.0, dtype=torch.float32) num_steps = min(len(scores), generated_suffix.shape[1]) logits = torch.stack(scores[:num_steps], dim=1).float() targets = generated_suffix[:, :num_steps].clone() if self.tokenizer.pad_token_id is not None: targets[targets == self.tokenizer.pad_token_id] = -100 yes_seq, no_seq = self._target_state_answer_sequences() decision_step = next((i for i, (yes_id, no_id) in enumerate(zip(yes_seq, no_seq)) if yes_id != no_id), None) if decision_step is not None and decision_step < targets.shape[1]: targets[:, decision_step] = -100 return F.cross_entropy( logits.reshape(-1, logits.shape[-1]), targets.reshape(-1), ignore_index=-100, ) def _teacher_decision_loss(self, scores, teacher_decisions): valid_items = [(idx, decision) for idx, decision in enumerate(teacher_decisions or []) if decision is not None] if not valid_items or scores is None or len(scores) == 0: device = scores[0].device if scores else self.qwen.get_input_embeddings().weight.device return torch.tensor(0.0, device=device), None yes_seq, no_seq = self._target_state_answer_sequences() decision_step = next((i for i, (yes_id, no_id) in enumerate(zip(yes_seq, no_seq)) if yes_id != no_id), None) if decision_step is None or decision_step >= len(scores): return scores[0].new_tensor(0.0), None batch_indices = torch.tensor([idx for idx, _ in valid_items], device=scores[decision_step].device, dtype=torch.long) target_ids = torch.tensor( [yes_seq[decision_step] if decision else no_seq[decision_step] for _, decision in valid_items], device=scores[decision_step].device, dtype=torch.long, ) logits = scores[decision_step].float().index_select(0, batch_indices) loss = F.cross_entropy(logits, target_ids) labels = torch.full((len(teacher_decisions),), -1, device=scores[decision_step].device, dtype=torch.long) labels[batch_indices] = torch.tensor([1 if decision else 0 for _, decision in valid_items], device=labels.device) return loss, labels def _target_hidden_from_generation(self, generation_hidden_states, generated_suffix): target_mask = generated_suffix.eq(self.target_token_id) if target_mask.any(dim=1).all(): target_pos = target_mask.float().argmax(dim=1) else: non_pad = generated_suffix.ne(self.tokenizer.pad_token_id) target_pos = non_pad.sum(dim=1).clamp_min(1) - 1 hidden_steps = generation_hidden_states or [] if len(hidden_steps) == 0: raise RuntimeError("Qwen generation did not return hidden states.") h_targets = [] for batch_idx, pos in enumerate(target_pos.detach().cpu().tolist()): # In cached generation, step t predicts generated token t. The hidden # state for generated token k is available at step k + 1, when that # token is fed back to predict the next token. step = min(pos + 1, len(hidden_steps) - 1) last_hidden = hidden_steps[step][-1] h_targets.append(last_hidden[batch_idx, -1]) return torch.stack(h_targets, dim=0).float() def forward(self, captions, template_images, search_images, template_boxes, search_boxes, device, object_names=None, return_update_decision=False, seq_names=None, template_frame_ids=None): if object_names is None: object_names = [None] * len(captions) prompts = [self._build_prompt(caption, object_name) for caption, object_name in zip(captions, object_names)] teacher_prompts = [self._build_teacher_prompt(caption, object_name) for caption, object_name in zip(captions, object_names)] template_pils = self._tensor_batch_to_pil(template_images, template_boxes) search_pils = self._tensor_batch_to_pil(search_images, search_boxes) # ---- resolve frame-level keys for teacher cache ---- cache_seq_names = None cache_fa = None cache_fb = None if seq_names is not None and template_frame_ids is not None: cache_seq_names = seq_names # template_frame_ids[:, -2] = old dynamic template, [:, -1] = new candidate cache_fa = template_frame_ids[:, -2].detach().cpu().tolist() cache_fb = template_frame_ids[:, -1].detach().cpu().tolist() teacher_decisions, teacher_outputs = self._query_teacher_decisions( teacher_prompts, template_pils, search_pils, seq_names=cache_seq_names, frame_ids_a=cache_fa, frame_ids_b=cache_fb, ) messages = [] for prompt, template_pil, search_pil in zip(prompts, template_pils, search_pils): messages.append([ { "role": "user", "content": [ {"type": "image", "image": template_pil}, {"type": "image", "image": search_pil}, {"type": "text", "text": prompt}, ], } ]) texts = [self._apply_qwen_chat_template(message) for message in messages] images = [[template_pil, search_pil] for template_pil, search_pil in zip(template_pils, search_pils)] if self.training: h_target, update_decisions, qwen_format_loss, qwen_teacher_loss, teacher_labels = self._qwen_forward_with_teacher_targets( texts, images, teacher_decisions, device ) response_outputs = teacher_outputs else: tokenized = self.processor(text=texts, images=images, padding=True, return_tensors="pt").to(device) generation = self._qwen_generate( **tokenized, **self._qwen_generation_kwargs(prompt_len=tokenized.input_ids.shape[1]), return_dict_in_generate=True, output_scores=True, output_hidden_states=True, ) generated_ids = generation.sequences generated_suffix = generated_ids[:, tokenized.input_ids.shape[1]:] decoded_outputs = self.tokenizer.batch_decode(generated_suffix, skip_special_tokens=False) update_decisions = torch.tensor( self._parse_update_decisions(decoded_outputs), device=device, dtype=torch.bool ) qwen_format_loss = self._format_loss_from_generation_scores(generation.scores, generated_suffix) qwen_teacher_loss, teacher_labels = self._teacher_decision_loss(generation.scores, teacher_decisions) h_target = self._target_hidden_from_generation(generation.hidden_states, generated_suffix) response_outputs = decoded_outputs z_target = self.projector(h_target) if return_update_decision: return z_target, update_decisions, qwen_format_loss, qwen_teacher_loss, teacher_labels, response_outputs return z_target def modulate_feature(self, opt_feat, z_target): """FiLM with per-channel learnable gate. Shapes: opt_feat (B, C, H, W) — tracker features z_target (B, C) — projected target-state embedding ``film_gate`` is a per-channel parameter initialised to sigmoid(-4) ≈ 0.018. This means modulation starts near identity and each channel independently learns how much to trust the target-state signal. """ z = self.film_ln(z_target) # P1: stabilise gamma, beta = self.film(z).chunk(2, dim=-1) # (B, C) each gate = torch.sigmoid(self.film_gate) # (C,) ∈ (0, 1) gamma = gamma[:, :, None, None] * gate[None, :, None, None] # (B, C, 1, 1) beta = beta[:, :, None, None] * gate[None, :, None, None] return opt_feat * (1.0 + gamma) + beta def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, freeze_bn=False): if freeze_bn: return nn.Sequential( nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True), FrozenBatchNorm2d(out_planes), nn.ReLU(inplace=True)) else: return nn.Sequential( nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True), nn.BatchNorm2d(out_planes), nn.ReLU(inplace=True)) class ConfidencePred(nn.Module): def __init__(self): super(ConfidencePred, self).__init__() self.feat_sz = 24 self.stride = 1 self.img_sz = self.feat_sz * self.stride freeze_bn = False # CNN self.conv1_ctr = conv(5, 16, freeze_bn=freeze_bn) self.conv2_ctr = conv(16, 16 // 2, freeze_bn=freeze_bn) self.conv3_ctr = conv(16 // 2, 16 // 4, freeze_bn=freeze_bn) self.conv4_ctr = conv(16 // 4, 16 // 8, freeze_bn=freeze_bn) self.conv5_ctr = nn.Conv2d(16 // 8, 1, kernel_size=1) # 定义全连接层 self.fc1 = nn.Linear(256, 512) ## cross attn 交互层 # self.multihead_attn = nn.MultiheadAttention(512, 4, dropout=0.1) # # Implementation of Feedforward model # self.dropout = nn.Dropout(0.1) # self.norm1 = nn.LayerNorm(512) self.fc2 = nn.Linear(512, 1) # 定义激活函数 self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, x,xz_feature=None, gt_score_map=None): """ Forward pass with input x. """ # ctr branch x_ctr1 = self.conv1_ctr(x) x_ctr2 = self.conv2_ctr(x_ctr1) x_ctr3 = self.conv3_ctr(x_ctr2) x_ctr4 = self.conv4_ctr(x_ctr3) score_map_ctr = self.conv5_ctr(x_ctr4) # 展平输入 x = score_map_ctr.flatten(1) x = self.relu(self.fc1(x)) x = self.sigmoid(self.fc2(x)) return x class SubjectIndexPred(nn.Module): def __init__(self,dim): super(SubjectIndexPred, self).__init__() # 定义全连接层 self.fc1 = nn.Linear(dim, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 1) self.sigmoid = nn.Sigmoid() for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, x): """ Forward pass with input x. """ # 全连接层前向传播 x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.sigmoid(self.fc3(x)) return x class ATCTrack(nn.Module): """ This is the base class for ATCTrack""" def __init__(self, transformer, box_head, tokenizer, text_encoder, aux_loss=False, head_type="CORNER",dim=512,cfg=None): """ Initializes the model. Parameters: encoder: torch module of the encoder to be used. See encoder.py decoder: torch module of the decoder architecture. See decoder.py """ super().__init__() self.backbone = transformer self.box_head = box_head self.aux_loss = aux_loss self.head_type = head_type if head_type == "CORNER" or head_type == "CENTER": self.feat_sz_s = int(box_head.feat_sz) self.feat_len_s = int(box_head.feat_sz ** 2) if self.aux_loss: self.box_head = _get_clones(self.box_head, 6) self.dim = dim self.query_len = 1 self.cls_prompts_pos = nn.Embedding(num_embeddings=self.query_len, embedding_dim=self.dim ) # pos for cur query # self.cls_initial= nn.Embedding(num_embeddings=self.query_len, embedding_dim=self.dim ) # pos for cur query self.confidence_pred = ConfidencePred() ### visual temporal self.visual_temporal_fusion = build_transformer_dec_with_mask(cfg, self.dim ) self.temporal_len = 4 self.dy_template_pos_embed = nn.Embedding(num_embeddings=self.temporal_len, embedding_dim=self.dim ) # pos for cur query ## invlove_text self.tokenizer = tokenizer self.text_encoder = text_encoder self.text_adj = nn.Sequential( nn.Linear(768, self.dim , bias=True), nn.LayerNorm(self.dim , eps=1e-12), nn.Dropout(0.1), ) self.language_adjust = build_transformer_dec(cfg, self.dim ) self.vl_fusion = VisionLanguageFusionModule(dim=self.dim , num_heads=8, attn_drop=0.1, proj_drop=0.1, num_vlfusion_layers=2, vl_input_type='separate') self.text_pos = PositionEmbeddingSine1D(self.dim , normalize=True) self.text_sub_idnex_classifier = SubjectIndexPred(self.dim) self.use_target_state = getattr(cfg.MODEL.TARGET_STATE, "ENABLE", False) if hasattr(cfg.MODEL, "TARGET_STATE") else False if self.use_target_state: self.target_state_encoder = QwenTargetStateEncoder(cfg, self.dim) else: self.target_state_encoder = None def forward_backbone(self, template, search, cls_token,soft_token_template_mask,x_pos): # template b, 12, h,w # search b,6,h,w template = [template[:,:3],template[:,3:]] soft_token_template_mask = [soft_token_template_mask[:, :64], soft_token_template_mask[:, 64:]] x, token_type_infor = self.backbone.forward_features_pe(z=template, x=search, soft_token_template_mask =soft_token_template_mask) x, aux_dict = self.backbone.forward_features_stage3(x, cls_token,x_pos) return x, aux_dict def forward(self, template: torch.Tensor, search: torch.Tensor, soft_token_template_mask=None, exp_str=None, exp_subject_mask=None, target_state_exp_str=None, target_state_template_bbox=None, target_state_new_template_bbox=None, target_state_object_name=None, target_state_z=None, target_state_seq_names=None, target_state_template_frame_ids=None, temporal_infor=[], first_frame_flag=False, training=True): b0, num_search = template[0].shape[0], len(search) z_target = None target_state_update_decision = None qwen_format_loss = None qwen_teacher_loss = None qwen_teacher_labels = None qwen_teacher_outputs = None target_state_captions = target_state_exp_str if target_state_exp_str is not None else exp_str if training: search = torch.cat(search, dim=0) if self.use_target_state and target_state_captions and len(template) >= 3: z_target, target_state_update_decision, qwen_format_loss, qwen_teacher_loss, qwen_teacher_labels, qwen_teacher_outputs = self.target_state_encoder( target_state_captions, template[-2], template[-1], target_state_template_bbox, target_state_new_template_bbox, search.device, object_names=target_state_object_name, return_update_decision=True, seq_names=target_state_seq_names, template_frame_ids=target_state_template_frame_ids, ) selector = target_state_update_decision.view(b0, 1, 1, 1) dynamic_template = torch.where(selector, template[-1], template[-2]) dynamic_mask = torch.where( target_state_update_decision.view(b0, 1, 1), soft_token_template_mask[-1], soft_token_template_mask[-2], ) else: dynamic_template = template[1] dynamic_mask = soft_token_template_mask[1] template = torch.cat([template[0], dynamic_template], dim=1) soft_token_template_mask = torch.cat([soft_token_template_mask[0], dynamic_mask], dim=1) template_temporal = [] soft_token_template_mask_temporal = [] for _ in range(num_search): template_temporal.append(template) soft_token_template_mask_temporal.append(soft_token_template_mask) template_temporal = torch.cat(template_temporal, dim=0) soft_token_template_mask_temporal = torch.cat(soft_token_template_mask_temporal,dim=0) else: b0 = 1 if target_state_z is not None: z_target = target_state_z.to(device=search.device) template_temporal = torch.cat(template[:2], dim=1) soft_token_template_mask_temporal = torch.cat(soft_token_template_mask[:2], dim=1) elif self.use_target_state and target_state_captions and len(template) >= 3: z_target, target_state_update_decision, qwen_format_loss, qwen_teacher_loss, qwen_teacher_labels, qwen_teacher_outputs = self.target_state_encoder( target_state_captions, template[-2], template[-1], target_state_template_bbox, target_state_new_template_bbox, search.device, object_names=target_state_object_name, return_update_decision=True, seq_names=target_state_seq_names, template_frame_ids=target_state_template_frame_ids, ) dynamic_template = template[-1] if bool(target_state_update_decision[0].item()) else template[-2] dynamic_mask = soft_token_template_mask[-1] if bool(target_state_update_decision[0].item()) else soft_token_template_mask[-2] template_temporal = torch.cat([template[0], dynamic_template], dim=1) soft_token_template_mask_temporal = torch.cat([soft_token_template_mask[0], dynamic_mask], dim=1) else: template_temporal = torch.cat(template[:2], dim=1) soft_token_template_mask_temporal = torch.cat(soft_token_template_mask[:2], dim=1) # x, aux_dict = self.backbone(z=template, x=search, # soft_token_template_mask = soft_token_template_mask ) cls_prompts_pos = self.cls_prompts_pos.weight.unsqueeze(0) x_pos_0 = torch.cat([cls_prompts_pos, self.backbone.pos_embed_z, self.backbone.pos_embed_x], dim=1) # pos_embed = x_pos.transpose(0, 1).repeat(1, b0, 1) x_pos = x_pos_0.repeat(b0*num_search, 1, 1) x, aux_dict = self.forward_backbone(template_temporal, search, None, soft_token_template_mask_temporal, x_pos) # forward Language branch if training: if exp_str: text_features, text_subject_features, subject_infor_mask_pred, subject_infor_mask_gt = self.forward_text( exp_str, num_search, exp_subject_mask, device=search.device) # text_subject_features, subject_infor_mask_pred, subject_infor_mask_gt else: text_features = exp_str text_subject_features = exp_subject_mask subject_infor_mask_pred = None subject_infor_mask_gt = None if z_target is not None and z_target.shape[0] == b0 and num_search > 1: z_target = torch.cat([z_target for _ in range(num_search)], dim=0) batch_size = text_features.tensors.shape[0] text_pos = self.text_pos(text_features) # [batch_size, length, c] text_pos_0 = text_pos[:b0] x_s_pos_item = x_pos_0.repeat(b0, 1, 1)[:, -self.feat_len_s:] pre_temporal_pos = self.dy_template_pos_embed.weight.unsqueeze(1) pre_temporal_pos = pre_temporal_pos.repeat(b0, 1, self.query_len) pre_temporal_pos = pre_temporal_pos.view(b0, self.temporal_len * self.query_len, self.dim).contiguous() # Forward temporal xt_data = [] for temporal_index in range(num_search): x_item = x[temporal_index * b0:(temporal_index + 1) * b0] visual_prompts_token = x_item[:, :self.query_len, :] ## heatmap by backbone feat ## by attn # attn_xz = attn[:, :, :-self.feat_len_s, -self.feat_len_s:] # b,h,l,l # attn_xz_1 = attn_xz.mean(1).mean(1) # # attn_xz = attn_xz.view(16, 16) # # attn_weights_debug = attn_xz.detach().cpu().numpy() x_f = x_item[:, -256:] x_f1 = torch.matmul(x_f, x_f.permute(0, 2, 1).contiguous()) x_f = torch.matmul(x_f1, x_f) z_f = x_item[:, :-256] x_z = torch.matmul(x_f, z_f.permute(0, 2, 1).contiguous()) att_map = x_z.mean(-1) tensor_min = torch.min(att_map) tensor_max = torch.max(att_map) # normalized_tensor = (s_vl_1 - tensor_min) / (tensor_max - tensor_min) normalized_tensor = (tensor_max - att_map) / (tensor_max - tensor_min) attn_xz = normalized_tensor.view(-1, 256,1).contiguous() ### initialize & update memory if training: if temporal_index == 0: temporal_infor = [] for _ in range(self.temporal_len): temporal_infor.append(visual_prompts_token) else: if first_frame_flag: temporal_infor = [] for _ in range(self.temporal_len): temporal_infor.append(visual_prompts_token) temporal_infor_data = torch.cat(temporal_infor, dim=1) #### vl fusion ############ ## L adjust l_item_initial = text_features.tensors[temporal_index * b0:(temporal_index + 1) * b0] l_item_subject = text_subject_features.tensors[temporal_index * b0:(temporal_index + 1) * b0] l_mask_item_0 = text_features.mask[temporal_index * b0:(temporal_index + 1) * b0] temporal_mask = torch.ones((l_mask_item_0.shape[0],self.temporal_len)).bool().to(l_mask_item_0.device) l_mask_item = torch.cat([l_mask_item_0, temporal_mask],dim=1) l_subject_temporal = torch.cat([l_item_subject,temporal_infor_data],dim=1) l_subject_temporal_pos = torch.cat([text_pos_0,pre_temporal_pos ],dim=1) l_item_update,_ = self.language_adjust([l_item_initial,l_subject_temporal],None, text_pos_0,l_subject_temporal_pos,l_mask_item) l_all = torch.cat([ l_item_initial,l_item_update ],dim=1) x_s_item = x_item[:, -self.feat_len_s:] x_s_item = self.vl_fusion(x_s_item, l_all, query_pos=x_pos_0[:, -self.feat_len_s:], memory_pos=torch.cat([text_pos_0,text_pos_0],dim=1), memory_key_padding_mask=torch.cat([l_mask_item_0,l_mask_item_0],dim=1), need_weights=False) #### cross_attention with temporal_infor temporal_infor_update = self.visual_temporal_fusion(temporal_infor_data, x_s_item, attn_xz,pre_temporal_pos ,kv_pos= x_s_pos_item ) temporal_item = temporal_infor_update[:,-1,:].unsqueeze(1) # STM enc_opt = x_s_item dec_opt = temporal_item.transpose(1, 2) att = torch.matmul(enc_opt, dec_opt) opt = (enc_opt.unsqueeze(-1) * att.unsqueeze(-2)).permute((0, 3, 2, 1)).contiguous() bs, Nq, C, HW = opt.size() opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s) if z_target is not None: z_item = z_target[temporal_index * b0:(temporal_index + 1) * b0] opt_feat = self.target_state_encoder.modulate_feature(opt_feat, z_item) xt_data.append(opt_feat) ### update temporal infor if training: if temporal_index == 0: temporal_infor = [] for _ in range(self.temporal_len): temporal_infor.append(temporal_item) else: temporal_infor[:-1] = temporal_infor[1:] temporal_infor[-1] = temporal_item else: if first_frame_flag: temporal_infor = [] for _ in range(self.temporal_len): temporal_infor.append(temporal_item) else: temporal_infor[:-1] = temporal_infor[1:] temporal_infor[-1] = temporal_item # Forward head xt_data = torch.cat(xt_data,dim=0) out = self.forward_head(xt_data, None) out.update(aux_dict) out['backbone_feat'] = x out['subject_infor_mask_pred'] = subject_infor_mask_pred out['subject_infor_mask_gt'] = subject_infor_mask_gt out['target_state_update_decision'] = target_state_update_decision out['qwen_format_loss'] = qwen_format_loss out['qwen_teacher_loss'] = qwen_teacher_loss out['qwen_teacher_labels'] = qwen_teacher_labels out['qwen_teacher_outputs'] = qwen_teacher_outputs if training == False: out["temporal_infor"] = temporal_infor return out def forward_head(self, opt_feat, gt_score_map=None): """ cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) """ # enc_opt = cat_feature #[:, -self.feat_len_s:] # encoder output for the search region (B, HW, C) # opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous() # bs, Nq, C, HW = opt.size() # opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s).contiguous() bs = opt_feat.shape[0] Nq = 1 # Head if self.head_type == "CORNER": # run the corner head pred_box, score_map = self.box_head(opt_feat, True) outputs_coord = box_xyxy_to_cxcywh(pred_box) outputs_coord_new = outputs_coord.view(bs, Nq, 4).contiguous() out = {'pred_boxes': outputs_coord_new, 'score_map': score_map, } return out elif self.head_type == "CENTER": # run the center head score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map) # outputs_coord = box_xyxy_to_cxcywh(bbox) score_map = torch.cat([score_map_ctr, size_map, offset_map], dim=1) confidence_pred = self.confidence_pred(score_map) outputs_coord = bbox outputs_coord_new = outputs_coord.view(bs, Nq, 4).contiguous() out = {'pred_boxes': outputs_coord_new, 'score_map': score_map_ctr, 'size_map': size_map, 'offset_map': offset_map, "confidence_pred": confidence_pred} return out else: raise NotImplementedError def forward_text(self, captions, num_search, exp_subject_mask, device): tokenized = self.tokenizer(captions, padding=True, return_tensors="pt").to(device) encoded_text = self.text_encoder(**tokenized) text_attention_mask = tokenized.attention_mask.ne(1).bool() # text_attention_mask: [batch_size, length] text_features = encoded_text.last_hidden_state text_features = self.text_adj(text_features) encodings_infor = tokenized.encodings subject_infor_mask_gt = None if exp_subject_mask is not None: # train: given the exp_subject_mask, used for generating sub_index_gt subject_infor_mask_gt = torch.zeros(text_attention_mask.shape[0], text_attention_mask.shape[1]).to( text_features.device) for item_index, item in enumerate(encodings_infor): word_ids_item = item.word_ids exp_subject_mask_item = exp_subject_mask[item_index] text_index_list = [] for word_index, word_item in enumerate(word_ids_item): if word_item in exp_subject_mask_item: text_index_list.append(word_index) subject_infor_mask_gt[item_index, text_index_list] = 1 subject_infor_mask_pred = self.text_sub_idnex_classifier(text_features) subject_infor_mask_pred_1 = subject_infor_mask_pred.expand_as(text_features) subject_infor = text_features * subject_infor_mask_pred_1 # (B,L,D) to (T,B,L,D) text_features_t = [] text_attention_mask_t = [] text_subject_infor_t = [] for i in range(num_search): text_features_t.append(text_features) text_attention_mask_t.append(text_attention_mask) text_subject_infor_t.append(subject_infor) text_features = torch.cat(text_features_t, dim=0) text_attention_mask = torch.cat(text_attention_mask_t, dim=0) text_features = NestedTensor(text_features, text_attention_mask) subject_infor = torch.cat(text_subject_infor_t, dim=0) subject_infor = NestedTensor(subject_infor, text_attention_mask) return text_features, subject_infor, subject_infor_mask_pred, subject_infor_mask_gt class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x def build_atctrack(cfg, training=True): current_dir = os.path.dirname(os.path.abspath(__file__)) # This is your Project Root pretrained_path = os.path.join(current_dir, '../../../resource/pretrained_models') if cfg.MODEL.PRETRAIN_FILE and training and ("ATCTrack" not in cfg.MODEL.PRETRAIN_FILE) : pretrained = os.path.join(pretrained_path, cfg.MODEL.PRETRAIN_FILE) else: pretrained = '' if cfg.MODEL.BACKBONE.TYPE == 'hivit_base_adaptor': backbone = hivit_base(pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE) hidden_dim = backbone.embed_dim patch_start_index = 1 elif cfg.MODEL.BACKBONE.TYPE == 'itpn_base': # by this backbone = fast_itpn_base_3324_patch16_224(pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE) hidden_dim = backbone.embed_dim patch_start_index = 1 elif cfg.MODEL.BACKBONE.TYPE == 'itpn_large': # by this backbone = fast_itpn_large_2240_patch16_256(pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE) hidden_dim = backbone.embed_dim patch_start_index = 1 else: raise NotImplementedError backbone.finetune_track(cfg=cfg,dim=hidden_dim, patch_start_index=patch_start_index) box_head = build_box_head(cfg, hidden_dim) # Build Text Encoder roberta_path = _resolve_project_path(os.environ.get("ROBERTA_MODEL_PATH", os.path.join(pretrained_path, 'roberta-base'))) tokenizer = RobertaTokenizerFast.from_pretrained(roberta_path) # load pretrained RoBERTa Tokenizer text_encoder = RobertaModel.from_pretrained(roberta_path) # load pretrained RoBERTa model model = ATCTrack( backbone, box_head, tokenizer, text_encoder, aux_loss=False, head_type=cfg.MODEL.HEAD.TYPE, dim = hidden_dim, cfg=cfg ) pretrained_checkpoint = _resolve_project_path(cfg.MODEL.PRETRAINED_PATH) if ("ATCTrack" in pretrained_checkpoint) and training: checkpoint = torch.load(pretrained_checkpoint, map_location="cpu", weights_only=False) ckpt = checkpoint["net"] model_weight = {} for k, v in ckpt.items(): model_weight[k] = v missing_keys, unexpected_keys = model.load_state_dict(model_weight, strict=False) print('Load pretrained model from: ' + cfg.MODEL.PRETRAIN_FILE) return model def load_pretrained(model, pretrained_path, strict=False): model_ckpt = torch.load(pretrained_path, map_location="cpu") state_dict = model_ckpt['net'] pos_st = state_dict['encoder.body.pos_embed'] pos_s = pos_st[:,:(pos_st.size(1) // 2)] pos_t = pos_st[:,(pos_st.size(1) // 2):] state_dict['encoder.body.pos_embed_search'] = pos_s state_dict['encoder.body.pos_embed_template'] = pos_t state_dict['encoder.body.patch_embed_interface.proj.weight'] = state_dict['encoder.body.patch_embed.proj.weight'] state_dict['encoder.body.patch_embed_interface.proj.bias'] = state_dict['encoder.body.patch_embed.proj.bias'] state_dict['decoder.embedding.prompt_embeddings.weight'] = model.state_dict()['decoder.embedding.prompt_embeddings.weight'] state_dict['decoder.embedding.prompt_embeddings.weight'][:] = state_dict['decoder.embedding.word_embeddings.weight'][-1] del state_dict['encoder.body.pos_embed'] model.load_state_dict(state_dict, strict=strict)