# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import torch
import torch.nn.functional as F
import torch.distributions as dists
from typing import Dict, Optional


def get_token_ids_from_config(config) -> Dict[str, int]:
    """Extract all special token IDs from a configuration object.

    Every ID falls back to a hard-coded default when the attribute is missing.

    Args:
        config: Configuration object (LocateAnythingConfig or similar). Its
            optional ``text_config`` attribute supplies the null / EOS /
            switch / mask token IDs.

    Returns:
        Dictionary mapping token-role names to token IDs.
    """
    token_ids = {
        # Read from the main config.
        'box_start_token_id': getattr(config, 'box_start_token_id', 151668),
        'box_end_token_id': getattr(config, 'box_end_token_id', 151669),
        'coord_start_token_id': getattr(config, 'coord_start_token_id', 151677),
        'coord_end_token_id': getattr(config, 'coord_end_token_id', 152677),
        'ref_start_token_id': getattr(config, 'ref_start_token_id', 151672),
        'ref_end_token_id': getattr(config, 'ref_end_token_id', 151673),
        'none_token_id': getattr(config, 'none_token_id', 4064),
    }

    # Read from text_config. When text_config is None, getattr() on None simply
    # yields the default, so no separate branch is needed.
    text_config = getattr(config, 'text_config', None)
    token_ids['null_token_id'] = getattr(text_config, 'null_token_id', 152678)
    token_ids['im_end_token_id'] = getattr(text_config, 'eos_token_id', 151645)
    token_ids['switch_token_id'] = getattr(text_config, 'switch_token_id', 152679)
    token_ids['default_mask_token_id'] = getattr(text_config, 'text_mask_token_id', 151676)
    return token_ids


def top_p_logits(
    logits: torch.Tensor,
    top_p: Optional[float] = None
) -> torch.Tensor:
    """Apply nucleus (top-p) filtering along the last dimension.

    Args:
        logits: Logits tensor; filtering is done over the last dimension.
        top_p: Cumulative probability threshold. ``None`` disables filtering.

    Returns:
        Logits where tokens outside the nucleus are set to the minimum finite
        value of the logits dtype.
    """
    if top_p is None:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the removal flags from sorted order back to vocabulary order.
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)


def top_k_logits(
    logits: torch.Tensor,
    top_k: Optional[int] = None
) -> torch.Tensor:
    """Keep only the ``top_k`` largest logits along the last dimension.

    Args:
        logits: Logits tensor; filtering is done over the last dimension.
        top_k: Number of tokens to keep. ``None`` or a non-positive value
            disables filtering; values above the vocabulary size are clamped.

    Returns:
        Logits where everything below the k-th largest value is set to the
        minimum finite value of the logits dtype.
    """
    if top_k is None:
        return logits
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k <= 0:
        return logits

    # Remove all tokens with a logit less than the last token of the top-k
    kth_value = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_value, torch.finfo(logits.dtype).min)


def apply_repetition_penalty(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    repetition_penalty: float = 1.0
) -> torch.Tensor:
    """
    Apply repetition penalty to logits.

    Args:
        logits: Shape [batch_size, seq_len, vocab_size] or [batch_size, vocab_size]
        input_ids: Previously generated token ids, shape [batch_size, seq_len]
        repetition_penalty: Penalty factor. > 1.0 penalizes repetition, < 1.0 encourages it.

    Returns:
        Modified logits with repetition penalty applied (same rank as the input).
    """
    if repetition_penalty == 1.0:
        return logits

    # Convert to 3D for vectorized computation
    squeeze_back = logits.dim() == 2
    if squeeze_back:
        logits = logits.unsqueeze(1)  # [B, 1, V]

    batch_size, seq_len, vocab_size = logits.shape

    # Construct [B, V] bool mask marking tokens that have appeared in each batch row
    token_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=logits.device)
    for b in range(batch_size):
        unique_tokens = input_ids[b].unique()
        # Prevent out-of-bounds: only keep IDs within vocab range
        valid_tokens = unique_tokens[(unique_tokens >= 0) & (unique_tokens < vocab_size)]
        if valid_tokens.numel() > 0:
            token_mask[b, valid_tokens] = True

    # Expand to [B, L, V] to align with logits
    token_mask = token_mask.unsqueeze(1).expand(-1, seq_len, -1)

    # Divide positive values by penalty, multiply non-positive values by penalty,
    # and apply the result only at mask positions.
    penalized = torch.where(logits > 0, logits / repetition_penalty, logits * repetition_penalty)
    logits = torch.where(token_mask, penalized, logits)

    if squeeze_back:
        logits = logits.squeeze(1)
    return logits


def _sample_from_logits(logits: torch.Tensor, generated: torch.Tensor, generate_kwargs: dict):
    """Shared penalty -> temperature -> top-p -> top-k -> sample pipeline.

    Reads ``repetition_penalty``, ``temperature``, ``top_p`` and ``top_k``
    from ``generate_kwargs``.

    Returns:
        Tuple ``(logits, probs, confidence, x0)`` where ``logits`` are the
        processed logits, ``probs`` their softmax, ``x0`` the chosen token IDs
        and ``confidence`` the probability of each chosen token.
    """
    repetition_penalty = generate_kwargs.get('repetition_penalty', 1.0)
    temperature = generate_kwargs.get('temperature', 0)
    top_p = generate_kwargs.get('top_p', None)
    top_k = generate_kwargs.get('top_k', None)

    # Apply repetition penalty based on all previously generated tokens
    if repetition_penalty != 1.0:
        logits = apply_repetition_penalty(logits, generated, repetition_penalty)

    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Deliberate best-effort: if sampling fails, fall back to greedy
            # decoding instead of aborting generation.
            confidence, x0 = probs.max(dim=-1)
    else:
        # Greedy: directly take the token with maximum probability
        confidence, x0 = probs.max(dim=-1)

    return logits, probs, confidence, x0


def sample_tokens(
    logits: torch.Tensor,
    generated: torch.Tensor,
    token_ids: Dict[str, int],
    **generate_kwargs,
):
    """Sample tokens and, for multi-position input, decode a box/ref pattern.

    Args:
        logits: Shape [batch_size, seq_len, vocab_size].
        generated: Previously generated token ids, used for the repetition penalty.
        token_ids: Dictionary containing all token IDs.
        **generate_kwargs: ``repetition_penalty``, ``temperature``, ``top_p``,
            ``top_k``, ``keep_k_avg`` (default 4) and ``generation_mode``
            (default ``'hybrid'``).

    Returns:
        Tuple ``(probs, confidence, x0, box_avg)``. ``box_avg`` is ``None``
        when ``seq_len == 1``; otherwise it is the per-sample decoded pattern
        stacked along the batch dimension.
    """
    batch_size, seq_len, vocab_size = logits.shape

    logits, probs, confidence, x0 = _sample_from_logits(logits, generated, generate_kwargs)

    if seq_len == 1:
        return probs, confidence, x0, None

    keep_k = generate_kwargs.get('keep_k_avg', 4)
    generation_mode = generate_kwargs.get('generation_mode', 'hybrid')

    # NOTE(review): the fallback has length 1 while decoded patterns are longer,
    # so torch.stack below fails if a batch mixes both. Kept as-is because
    # callers may rely on the length-1 fallback; confirm batch_size is 1 here.
    fallback_box = torch.zeros(1, dtype=x0.dtype, device=x0.device)

    box_avg = []
    for b in range(batch_size):
        decoded = decode_bbox_avg(
            logits[b],
            probs[b],
            token_ids,
            keep_k=keep_k,
            generation_mode=generation_mode,
        )
        if decoded is None:
            out_ref = decode_ref(logits[b], probs[b], token_ids)
            if out_ref is not None:
                # .to() avoids the copy-construct warning of torch.tensor(tensor).
                decoded = out_ref.to(dtype=x0.dtype, device=x0.device)
            else:
                decoded = fallback_box
        box_avg.append(decoded)

    box_avg = torch.stack(box_avg)
    return probs, confidence, x0, box_avg


def sample_tokens_ar(
    logits: torch.Tensor,
    generated: torch.Tensor,
    token_ids: Dict[str, int],
    **generate_kwargs,
):
    """
    Lightweight sampling function for AR single-step sampling only.

    Args:
        logits: [batch_size, vocab_size] or [batch_size, 1, vocab_size]
        generated: [batch_size, seq_len]
        token_ids: Dictionary containing all token IDs (not read here).
        **generate_kwargs: ``repetition_penalty``, ``temperature``, ``top_p``, ``top_k``.

    Returns:
        Tuple ``(probs, confidence, x0, None, None)`` with ``probs`` of shape
        [B, 1, V] and ``confidence`` / ``x0`` of shape [B, 1].
    """
    # Convert to 3D for reusing repetition penalty and clipping logic
    if logits.dim() == 2:
        logits = logits.unsqueeze(1)  # [B, 1, V]

    batch_size, seq_len, vocab_size = logits.shape
    assert seq_len == 1, "sample_tokens_ar only supports single-step AR sampling (seq_len == 1)"

    _, probs, confidence, x0 = _sample_from_logits(logits, generated, generate_kwargs)

    # NOTE(review): this returns five values whereas sample_tokens returns four;
    # left unchanged because existing callers unpack this arity.
    return probs, confidence, x0, None, None


def is_valid_box_frame(
    probs,
    token_ids: Dict[str, int],
    start_thresh=0.6,
    end_thresh=0.2,
    topk=5,
):
    """Classify a 6-position probability frame as a box pattern.

    Args:
        probs: Probabilities indexed as ``probs[position, token_id]``;
            positions 0-5 are read.
        token_ids: Dictionary containing all token IDs.
        start_thresh: Minimum probability of the box-start token at position 0.
        end_thresh: Minimum summed probability of box-end / null / im-end
            tokens at position 5.
        topk: Unused; kept for interface compatibility.

    Returns:
        ``'empty_box'``, ``'legal_box'`` or ``'illegal_box'``.
    """
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    null_token_id = token_ids['null_token_id']
    im_end_token_id = token_ids['im_end_token_id']
    none_token_id = token_ids['none_token_id']

    if probs[0, box_start_token_id] >= start_thresh:
        # "None" box: <box_start> <none> <box_end> <null> <null> ...
        if (probs[1, none_token_id] > 0.2
                and probs[2, box_end_token_id] > 0.2
                and probs[3, null_token_id] > 0.1
                and probs[4, null_token_id] > 0.1):
            return 'empty_box'

        # Regular box: position 5 must look like a terminator.
        end_target_ids = torch.tensor(
            [box_end_token_id, null_token_id, im_end_token_id],
            device=probs.device,
        )
        if probs[5, end_target_ids].sum() >= end_thresh:
            return 'legal_box'

    return 'illegal_box'


def decode_bbox_avg(
    logits,
    probs,
    token_ids: Dict[str, int],
    keep_k=5,
    start_thresh=0.7,
    end_thresh=0.2,
    generation_mode: str = 'hybrid',
):
    """
    Decode a bounding-box token pattern from a 6-position frame.

    For each of the four coordinate positions the most probable coordinate
    token within the top-k candidates is selected. In ``'hybrid'`` mode a
    position whose candidates look ambiguous is replaced with token id 0.

    Args:
        logits: Logits of shape (6, vocab_size); only its device is used.
        probs: Probability distribution of shape (6, vocab_size)
        token_ids: Dictionary containing all token IDs
        keep_k: Number of top-k candidate tokens to keep at each position
            (clamped to the vocabulary size)
        start_thresh: Confidence threshold for box start token
        end_thresh: Confidence threshold for box end token
        generation_mode: ``'hybrid'`` or ``'fast'``

    Returns:
        1-D long tensor ``[box_start, c1, c2, c3, c4, box_end]``, the
        empty-box pattern ``[box_start, none, box_end, null, null, null]``,
        or None if the frame is not a box.

    Raises:
        ValueError: If a legal box is found and ``generation_mode`` is neither
            ``'hybrid'`` nor ``'fast'``.
    """
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    none_token_id = token_ids['none_token_id']
    null_token_id = token_ids['null_token_id']

    device = logits.device

    box_type = is_valid_box_frame(
        probs, token_ids, start_thresh=start_thresh, end_thresh=end_thresh, topk=keep_k
    )
    if box_type == 'empty_box':
        # Handle the none case first
        return torch.tensor(
            [box_start_token_id, none_token_id, box_end_token_id,
             null_token_id, null_token_id, null_token_id],
            dtype=torch.long,
            device=probs.device,
        )
    if box_type == 'illegal_box':
        return None

    # Extract probabilities at positions 1-4 and compute Top-K for all 4 positions at once
    keep_k = min(keep_k, probs.size(-1))
    pos_probs, pos_ids = torch.topk(probs[1:5], k=keep_k, dim=-1)
    mask = (pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id)

    has_valid = mask.any(dim=-1)  # shape: [4]
    if not has_valid.all():
        return None  # not a box, exit...

    # topk is sorted by descending probability, so the first True in the mask
    # is the most probable coordinate token at each position.
    first_valid_idx = mask.long().argmax(dim=-1, keepdim=True)  # [4, 1]
    first_valid_probs = pos_probs.gather(-1, first_valid_idx).squeeze(-1)  # [4]
    first_valid_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)  # [4]

    if generation_mode == 'hybrid':
        valid_counts = mask.sum(dim=-1)  # [4]

        # Compute max/min of valid ids: fill invalid positions with extreme
        # values to avoid interfering with max/min
        LARGE_NUM, SMALL_NUM = 999999, -999999
        valid_max = pos_ids.masked_fill(~mask, SMALL_NUM).max(dim=-1)[0]
        valid_min = pos_ids.masked_fill(~mask, LARGE_NUM).min(dim=-1)[0]

        is_abnormal = (
            (first_valid_probs < 0.9)
            & (valid_counts > 1)
            & ((valid_max - valid_min) > 60)
        )

        # Normal positions take top-1 (first_valid_ids); abnormal positions are replaced with 0
        final_coords = first_valid_ids.masked_fill(is_abnormal, 0)
    elif generation_mode == 'fast':
        final_coords = first_valid_ids
    else:
        raise ValueError(
            f"Unsupported generation_mode {generation_mode!r}; expected 'hybrid' or 'fast'"
        )

    start_t = torch.tensor([box_start_token_id], dtype=final_coords.dtype, device=device)
    end_t = torch.tensor([box_end_token_id], dtype=final_coords.dtype, device=device)
    return torch.cat([start_t, final_coords, end_t])


def decode_ref(
    logits,
    probs,
    token_ids: Dict[str, int],
    keep_k=5,
    start_thresh=0.6,
):
    """Decode a reference-text pattern that starts with the ref-start token.

    Args:
        logits: Unused; kept for interface symmetry with ``decode_bbox_avg``.
        probs: Probabilities of shape (L, vocab_size).
        token_ids: Dictionary containing all token IDs.
        keep_k: Number of top-k candidates per position (clamped to the
            vocabulary size).
        start_thresh: Minimum probability of the ref-start token at position 0.

    Returns:
        1-D tensor ``[ref_start, t1, ..., t_{L-1}]`` of non-coordinate tokens,
        or None if the frame does not decode as a reference.
    """
    ref_start_token_id = token_ids.get('ref_start_token_id')
    if ref_start_token_id is None:
        # Without a ref-start token there is nothing to decode.
        return None
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']

    device = probs.device

    # 1. Check that the ref-start token at the first position meets start_thresh
    if probs[0, ref_start_token_id] < start_thresh:
        return None

    # 2. Extract Top-K probabilities and token IDs for all subsequent positions
    keep_k = min(keep_k, probs.size(-1))
    pos_probs, pos_ids = torch.topk(probs[1:], k=keep_k, dim=-1)  # shape: [L-1, keep_k]

    # 3. Build mask: coordinate tokens are invalid here, everything else is valid
    is_coord = (pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id)
    is_valid = ~is_coord  # shape: [L-1, keep_k]

    # Ensure each position has at least one non-coordinate valid token in its Top-K
    has_valid = is_valid.any(dim=-1)  # shape: [L-1]
    if not has_valid.all():
        return None

    # 4. Get the highest-probability valid token.
    # Since topk results are sorted in descending order of probability, argmax
    # returns the first index where is_valid is True, i.e. the most probable
    # valid token.
    first_valid_idx = is_valid.long().argmax(dim=-1, keepdim=True)  # shape: [L-1, 1]
    final_text_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)  # shape: [L-1]

    start_t = torch.tensor([ref_start_token_id], dtype=final_text_ids.dtype, device=device)
    return torch.cat([start_t, final_text_ids])


def handle_pattern(x0, token_ids: Dict[str, int], generation_mode: str = 'hybrid'):
    """
    Classify a decoded token frame and return the tokens to emit.

    Args:
        x0: Non-empty 1-D sequence of token IDs exposing ``.tolist()`` (e.g. a
            tensor), nominally of length 6. Shorter box frames are treated as
            malformed boxes instead of raising.
        token_ids: Dictionary containing all token IDs
        generation_mode: ``'fast'`` keeps malformed boxes as ``coord_box``;
            any other value reports them as ``error_box`` and requests a
            switch to AR decoding.

    Returns:
        Dict with keys ``type``, ``tokens``, ``need_switch_to_ar`` and
        ``is_terminal``.
    """
    null_token_id = token_ids['null_token_id']
    im_end_token_id = token_ids['im_end_token_id']
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    none_token_id = token_ids['none_token_id']
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    ref_end_token_id = token_ids['ref_end_token_id']

    x0 = x0.tolist()
    head = x0[0]

    # A leading null token is treated the same as an explicit end token.
    if head in (null_token_id, im_end_token_id):
        return {
            "type": "im_end",
            "tokens": [im_end_token_id],
            "need_switch_to_ar": False,
            "is_terminal": True,
        }

    if x0[:2] == [box_start_token_id, none_token_id]:
        return {
            "type": "empty_box",
            "tokens": [box_start_token_id, none_token_id, box_end_token_id],
            "need_switch_to_ar": False,
            "is_terminal": False,
        }

    if head == box_start_token_id:
        # coord_ix ends up as the index of the first non-coordinate token.
        coord_ix = 1
        for coord in x0[1:5]:
            if coord_start_token_id <= coord <= coord_end_token_id:
                coord_ix += 1
            else:
                break

        # Standard 4-coordinate bbox:
        if coord_ix == 5 and len(x0) > 5 and x0[5] == box_end_token_id:
            return {
                "type": "coord_box",
                "tokens": x0,
                "need_switch_to_ar": False,
                "is_terminal": False,
            }

        # Two-coordinate pointing:
        # Convention: the first two coordinates are valid coord tokens and the
        # token right after them is box_end. Remaining positions (if any) are
        # not part of the pattern; truncate at box_end.
        if coord_ix == 3 and len(x0) > 3 and x0[3] == box_end_token_id:
            return {
                "type": "point_box",
                "tokens": x0[:4],
                "need_switch_to_ar": False,
                "is_terminal": False,
            }

        if generation_mode == 'fast':
            # fast mode: treat as coord_box, stay in MTP
            return {
                "type": "coord_box",
                "tokens": x0,
                "need_switch_to_ar": False,
                "is_terminal": False,
            }

        # hybrid mode: error_box, switch to AR
        return {
            "type": "error_box",
            "tokens": x0[:coord_ix],
            "need_switch_to_ar": True,
            "is_terminal": False,
        }

    # Anything else is reference text: cut at the first null token ...
    if null_token_id in x0:
        x0 = x0[:x0.index(null_token_id)]
    # ... and collapse a duplicated trailing ref_end token.
    if len(x0) >= 2 and x0[-1] == x0[-2] == ref_end_token_id:
        x0 = x0[:-1]

    return {
        "type": "ref_object",
        "tokens": x0,
        "need_switch_to_ar": False,
        "is_terminal": False,
    }