# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
import torch.nn.functional as F
import torch.distributions as dists
from typing import Dict, Optional
def get_token_ids_from_config(config) -> Dict[str, int]:
    """Collect the special-token IDs used by the decoding helpers.

    Each ID is read with ``getattr`` and falls back to a hard-coded default
    when the attribute is missing.

    Args:
        config: Configuration object (LocateAnythingConfig or similar). It may
            optionally expose a ``text_config`` sub-configuration.

    Returns:
        Dictionary mapping token-role names to token IDs.
    """
    # (key / attribute name, fallback) pairs read from the top-level config.
    main_fields = (
        ('box_start_token_id', 151668),
        ('box_end_token_id', 151669),
        ('coord_start_token_id', 151677),
        ('coord_end_token_id', 152677),
        ('ref_start_token_id', 151672),
        ('ref_end_token_id', 151673),
        ('none_token_id', 4064),
    )
    # (result key, attribute name on text_config, fallback) triples.
    text_fields = (
        ('null_token_id', 'null_token_id', 152678),
        ('im_end_token_id', 'eos_token_id', 151645),
        ('switch_token_id', 'switch_token_id', 152679),
        ('default_mask_token_id', 'text_mask_token_id', 151676),
    )

    token_ids = {
        name: getattr(config, name, fallback) for name, fallback in main_fields
    }

    text_config = getattr(config, 'text_config', None)
    for key, attr, fallback in text_fields:
        if text_config is None:
            # No text sub-config at all: use the built-in defaults.
            token_ids[key] = fallback
        else:
            token_ids[key] = getattr(text_config, attr, fallback)
    return token_ids
def top_p_logits(
    logits: torch.Tensor,
    top_p: Optional[float] = None
) -> torch.Tensor:
    """Mask every logit that falls outside the nucleus (top-p) set.

    Tokens are ranked by logit along the last dimension. A token is masked
    when the cumulative softmax probability of the tokens ranked before it
    already exceeds ``top_p``; the first token that crosses the threshold is
    therefore kept, and the best token is always kept.

    Args:
        logits: Tensor whose last dimension is the vocabulary.
        top_p: Cumulative-probability budget. ``None`` or a value >= 1
            disables filtering and returns ``logits`` unchanged.

    Returns:
        A tensor of the same shape in which masked entries are set to the
        most negative finite value of ``logits.dtype``. The input is not
        modified in place.
    """
    # The declared default (None) used to crash on the comparison below;
    # treat "no budget" and "full budget" as a no-op instead.
    if top_p is None or top_p >= 1:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    over_budget = cumulative_probs > top_p

    # Shift right by one so the first token above the threshold survives;
    # rank 0 is never removed.
    remove_sorted = torch.zeros_like(over_budget)
    remove_sorted[..., 1:] = over_budget[..., :-1]

    # Map the decision from sorted order back to vocabulary order.
    remove = torch.zeros_like(logits, dtype=torch.bool)
    remove.scatter_(-1, sorted_indices, remove_sorted)
    return logits.masked_fill(remove, torch.finfo(logits.dtype).min)
def top_k_logits(
    logits: torch.Tensor,
    top_k: Optional[int] = None
) -> torch.Tensor:
    """Mask every logit that is smaller than the k-th largest one.

    Args:
        logits: Tensor whose last dimension is the vocabulary.
        top_k: Number of top entries to keep per row. ``None`` or a value
            <= 0 disables filtering; values larger than the vocabulary size
            are clamped to it.

    Returns:
        A tensor of the same shape in which masked entries are set to the
        most negative finite value of ``logits.dtype``. Entries tied with
        the k-th largest value are kept.
    """
    # The declared default (None) used to crash in min(); a non-positive k
    # used to crash when indexing the empty topk result.
    if top_k is None:
        return logits
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k <= 0:
        return logits

    # Per-row threshold: the smallest value among the k largest logits.
    kth_best = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_best, torch.finfo(logits.dtype).min)
def apply_repetition_penalty(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    repetition_penalty: float = 1.0
) -> torch.Tensor:
    """
    Rescale the logits of tokens that already occur in ``input_ids``.

    Args:
        logits: Shape [batch_size, seq_len, vocab_size] or [batch_size, vocab_size]
        input_ids: Previously generated token ids, shape [batch_size, seq_len]
        repetition_penalty: Penalty factor. > 1.0 penalizes repetition, < 1.0 encourages it.

    Returns:
        Logits of the same rank as the input with the penalty applied. When
        the penalty is exactly 1.0 the input tensor is returned as-is.
    """
    if repetition_penalty == 1.0:
        return logits

    # Work on a 3D view so both accepted layouts share one code path.
    was_2d = logits.dim() == 2
    scores = logits.unsqueeze(1) if was_2d else logits
    num_rows, num_steps, num_vocab = scores.shape

    # seen[b, v] is True when token v occurs in the history of row b.
    seen = torch.zeros(num_rows, num_vocab, dtype=torch.bool, device=scores.device)
    for row in range(num_rows):
        history = input_ids[row].unique()
        # Ignore IDs that cannot index the vocabulary axis.
        in_vocab = history[(history >= 0) & (history < num_vocab)]
        if in_vocab.numel() > 0:
            seen[row, in_vocab] = True
    # Broadcast the per-row mask over every step: [B, L, V].
    seen = seen.unsqueeze(1).expand(-1, num_steps, -1)

    # Positive logits shrink by division, all others by multiplication.
    rescaled = torch.where(
        scores > 0,
        scores / repetition_penalty,
        scores * repetition_penalty,
    )
    scores = torch.where(seen, rescaled, scores)

    return scores.squeeze(1) if was_2d else scores
def sample_tokens(
    logits: torch.Tensor,
    generated: torch.Tensor,
    token_ids: Dict[str, int],
    **generate_kwargs,
):
    """Filter logits, pick one token per position and decode a box/ref frame.

    Args:
        logits: Tensor unpacked as [batch_size, seq_len, vocab_size].
        generated: Previously generated token ids; only used for the
            repetition penalty.
        token_ids: Dictionary of special token IDs, forwarded to
            ``decode_bbox_avg`` and ``decode_ref``.
        **generate_kwargs: Recognised keys are ``repetition_penalty``
            (default 1.0), ``temperature`` (default 0, i.e. greedy),
            ``top_p``, ``top_k``, ``keep_k_avg`` (default 4) and
            ``generation_mode`` (default 'hybrid'). An explicit ``None`` for
            ``temperature`` or ``repetition_penalty`` is treated like the
            default; ``top_k <= 0`` disables top-k filtering.

    Returns:
        Tuple ``(probs, confidence, x0, box_avg)``. ``probs`` is the softmax
        of the filtered logits, ``x0`` the chosen token per position and
        ``confidence`` its probability. ``box_avg`` is ``None`` when
        ``seq_len == 1``; otherwise it stacks, per batch row, the result of
        ``decode_bbox_avg``, else of ``decode_ref``, else a length-1 zero
        tensor.
    """
    batch_size, seq_len, _ = logits.shape

    # Callers that forward optional settings may pass None explicitly; that
    # used to crash on the numeric comparisons below.
    repetition_penalty = generate_kwargs.get('repetition_penalty', 1.0)
    if repetition_penalty is None:
        repetition_penalty = 1.0
    temperature = generate_kwargs.get('temperature', 0) or 0
    top_p = generate_kwargs.get('top_p', None)
    top_k = generate_kwargs.get('top_k', None)

    # Apply repetition penalty based on all previously generated tokens
    if repetition_penalty != 1.0:
        logits = apply_repetition_penalty(logits, generated, repetition_penalty)
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None and top_k > 0:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Deliberate best-effort: if sampling fails for any reason,
            # fall back to the greedy choice instead of aborting generation.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if seq_len == 1:
        return probs, confidence, x0, None

    # Loop-invariant decoding options.
    keep_k = generate_kwargs.get('keep_k_avg', 4)
    generation_mode = generate_kwargs.get('generation_mode', 'hybrid')

    # NOTE(review): torch.stack needs equally sized entries. A decoded box
    # has 6 tokens while this fallback has 1, so a batch that mixes both
    # would fail to stack. The length is kept as-is because callers may rely
    # on it to detect the fallback -- confirm before changing.
    fallback_box = torch.zeros(1, dtype=x0.dtype, device=x0.device)

    box_avg = []
    for b in range(batch_size):
        decoded_box = decode_bbox_avg(
            logits[b], probs[b], token_ids, keep_k=keep_k,
            generation_mode=generation_mode,
        )
        if decoded_box is not None:
            box_avg.append(decoded_box)
            continue
        out_ref = decode_ref(logits[b], probs[b], token_ids)
        if out_ref is not None:
            # out_ref is already a tensor; .to() avoids the copy-construct
            # warning that torch.tensor(<tensor>) emits.
            box_avg.append(out_ref.to(dtype=x0.dtype, device=x0.device))
        else:
            box_avg.append(fallback_box)
    return probs, confidence, x0, torch.stack(box_avg)
def sample_tokens_ar(
    logits: torch.Tensor,
    generated: torch.Tensor,
    token_ids: Dict[str, int],
    **generate_kwargs,
):
    """
    Lightweight sampling function for AR single-step sampling only.

    Args:
        logits: [batch_size, vocab_size] or [batch_size, 1, vocab_size]
        generated: [batch_size, seq_len]; only used for the repetition
            penalty.
        token_ids: Dictionary of special token IDs. Not read by this
            function.
        **generate_kwargs: Recognised keys are ``repetition_penalty``
            (default 1.0), ``temperature`` (default 0, i.e. greedy),
            ``top_p`` and ``top_k``. An explicit ``None`` for
            ``temperature`` or ``repetition_penalty`` is treated like the
            default; ``top_k <= 0`` disables top-k filtering.

    Returns:
        Five-element tuple ``(probs, confidence, x0, None, None)`` with
        ``probs`` of shape [B, 1, V] and ``confidence`` / ``x0`` of shape
        [B, 1]. Note that this is one element longer than the tuple returned
        by ``sample_tokens``.
    """
    # Convert to 3D for reusing repetition penalty and clipping logic
    if logits.dim() == 2:
        logits = logits.unsqueeze(1)  # [B, 1, V]
    _, seq_len, _ = logits.shape
    assert seq_len == 1, "sample_tokens_ar only supports single-step AR sampling (seq_len == 1)"

    # Callers that forward optional settings may pass None explicitly; that
    # used to crash on the numeric comparisons below.
    repetition_penalty = generate_kwargs.get('repetition_penalty', 1.0)
    if repetition_penalty is None:
        repetition_penalty = 1.0
    temperature = generate_kwargs.get('temperature', 0) or 0
    top_p = generate_kwargs.get('top_p', None)
    top_k = generate_kwargs.get('top_k', None)

    # Apply repetition penalty only based on historically generated tokens
    if repetition_penalty != 1.0:
        logits = apply_repetition_penalty(logits, generated, repetition_penalty)
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None and top_k > 0:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Deliberate best-effort: if sampling fails for any reason,
            # fall back to the greedy choice instead of aborting generation.
            confidence, x0 = probs.max(dim=-1)
    else:
        # For greedy: directly take the token with maximum probability
        confidence, x0 = probs.max(dim=-1)

    return probs, confidence, x0, None, None
def is_valid_box_frame(
    probs,
    token_ids: Dict[str, int],
    start_thresh=0.6,
    end_thresh=0.2,
    topk=5,
):
    """Classify a frame of per-position probabilities as a box pattern.

    Args:
        probs: Probability tensor indexed as ``probs[position, token_id]``;
            positions 0 through 5 are read.
        token_ids: Dictionary containing all token IDs.
        start_thresh: Minimum probability of the box-start token at
            position 0.
        end_thresh: Minimum summed probability of the box-end, null and
            im_end tokens at position 5.
        topk: Accepted for interface compatibility; not used here.

    Returns:
        ``'empty_box'``, ``'legal_box'`` or ``'illegal_box'``.
    """
    start_id = token_ids['box_start_token_id']
    end_id = token_ids['box_end_token_id']
    null_id = token_ids['null_token_id']
    eos_id = token_ids['im_end_token_id']
    none_id = token_ids['none_token_id']  # none

    # Guard clause: without a confident box start nothing else is checked.
    if not (probs[0, start_id] >= start_thresh):
        return 'illegal_box'

    # Signature of "<box> none </box> null null": (position, token, floor).
    empty_signature = (
        (1, none_id, 0.2),
        (2, end_id, 0.2),
        (3, null_id, 0.1),
        (4, null_id, 0.1),
    )
    # all() over a generator stops at the first failing position.
    if all(probs[pos, tok] > floor for pos, tok, floor in empty_signature):
        return 'empty_box'

    # Position 5 must plausibly close the frame.
    closers = torch.tensor([end_id, null_id, eos_id], device=probs.device)
    closing_mass = probs[5, closers].sum()
    return 'legal_box' if closing_mass >= end_thresh else 'illegal_box'
def decode_bbox_avg(
    logits,
    probs,
    token_ids: Dict[str, int],
    keep_k=5,
    start_thresh=0.7,
    end_thresh=0.2,
    generation_mode: str = 'hybrid',
):
    """
    Decode one bounding-box frame from per-position probabilities.

    The frame is first classified with ``is_valid_box_frame``. For a legal
    box, the highest-probability coordinate token among the top-k candidates
    is selected at each of positions 1-4 (no averaging is performed despite
    the function name).

    Args:
        logits: Logits of shape (6, vocab_size). Not read by this function;
            kept for interface compatibility.
        probs: Probability distribution of shape (6, vocab_size)
        token_ids: Dictionary containing all token IDs
        keep_k: Number of top-k candidate tokens to keep at each position;
            clamped to the vocabulary size.
        start_thresh: Confidence threshold for box start token
        end_thresh: Confidence threshold for box end token
        generation_mode: ``'fast'`` always takes the best coordinate token.
            ``'hybrid'`` additionally replaces a coordinate with token id 0
            when its candidates look ambiguous (see below).

    Returns:
        A 1-D long tensor of 6 token ids: either
        ``[box_start, c1, c2, c3, c4, box_end]`` or, for an empty box,
        ``[box_start, none, box_end, null, null, null]``. ``None`` when the
        frame is not a box or a coordinate position has no coordinate token
        among its top-k candidates.

    Raises:
        ValueError: If a coordinate box is reached with a ``generation_mode``
            other than ``'hybrid'`` or ``'fast'``.
    """
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    none_token_id = token_ids['none_token_id']
    # Every tensor built here lives next to probs, the only tensor read.
    device = probs.device

    box_type = is_valid_box_frame(
        probs,
        token_ids,
        start_thresh=start_thresh,
        end_thresh=end_thresh,
        topk=keep_k
    )
    if box_type == 'empty_box':
        # Handle the none case first
        null_token_id = token_ids['null_token_id']
        return torch.tensor([
            box_start_token_id,
            none_token_id,
            box_end_token_id,
            null_token_id,
            null_token_id,
            null_token_id,
        ], dtype=torch.long, device=device)
    if box_type == 'illegal_box':
        return None

    # torch.topk rejects k larger than the last dimension.
    keep_k = min(keep_k, probs.size(-1))
    # Top-K candidates for the four coordinate positions at once: [4, keep_k]
    pos_probs, pos_ids = torch.topk(probs[1:5], k=keep_k, dim=-1)
    mask = (pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id)
    if not mask.any(dim=-1).all():
        return None  # some position has no coordinate candidate: not a box

    # topk output is sorted by descending probability, so the first True in
    # each mask row is the most probable coordinate token of that position.
    first_valid_idx = mask.long().argmax(dim=-1, keepdim=True)  # [4, 1]
    first_valid_probs = pos_probs.gather(-1, first_valid_idx).squeeze(-1)  # [4]
    first_valid_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)  # [4]

    if generation_mode == 'hybrid':
        # A position is ambiguous when its best coordinate is not dominant,
        # there are competing coordinate candidates, and those candidates
        # are spread far apart in token-id space.
        max_confident_prob = 0.9
        max_id_spread = 60

        valid_counts = mask.sum(dim=-1)  # [4]
        # Fill non-coordinate slots with the dtype extremes so they can never
        # win the max/min; every row holds at least one real coordinate id
        # (checked above), so the subtraction below cannot overflow.
        id_bounds = torch.iinfo(pos_ids.dtype)
        valid_max = pos_ids.masked_fill(~mask, id_bounds.min).max(dim=-1)[0]
        valid_min = pos_ids.masked_fill(~mask, id_bounds.max).min(dim=-1)[0]
        is_abnormal = (
            (first_valid_probs < max_confident_prob)
            & (valid_counts > 1)
            & ((valid_max - valid_min) > max_id_spread)
        )
        # Normal positions take top-1 (first_valid_ids); abnormal positions are replaced with 0
        final_coords = first_valid_ids.masked_fill(is_abnormal, 0)
    elif generation_mode == 'fast':
        final_coords = first_valid_ids
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError(
            f"Unsupported generation_mode {generation_mode!r}; "
            "expected 'hybrid' or 'fast'"
        )

    start_t = torch.tensor([box_start_token_id], dtype=final_coords.dtype, device=device)
    end_t = torch.tensor([box_end_token_id], dtype=final_coords.dtype, device=device)
    return torch.cat([start_t, final_coords, end_t])
def decode_ref(
    logits,
    probs,
    token_ids: Dict[str, int],
    keep_k=5,
    start_thresh=0.6,
):
    """Decode a reference (text) frame from per-position probabilities.

    The frame must start with the ref-start token. For every later position
    the most probable non-coordinate token among the top-k candidates is
    selected.

    Args:
        logits: Not read by this function; kept for interface compatibility.
        probs: Probability tensor indexed as ``probs[position, token_id]``.
        token_ids: Dictionary containing all token IDs.
        keep_k: Number of top-k candidates considered per position; clamped
            to the vocabulary size.
        start_thresh: Minimum probability of the ref-start token at
            position 0.

    Returns:
        1-D tensor ``[ref_start, t1, ..., t_{L-1}]`` with one token id per
        position of ``probs``, or ``None`` when position 0 is not a confident
        ref-start or some later position has only coordinate tokens among
        its top-k candidates.

    Raises:
        KeyError: If a required token id is missing from ``token_ids``.
    """
    # Index like the sibling lookups: .get() would hand None to the tensor
    # indexing below and fail with an unrelated-looking error.
    ref_start_token_id = token_ids['ref_start_token_id']
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    device = probs.device

    # 1. Position 0 must be the ref-start token with enough probability.
    if probs[0, ref_start_token_id] < start_thresh:
        return None

    # 2. Top-K token IDs for all subsequent positions: [L-1, keep_k].
    #    torch.topk rejects k larger than the last dimension.
    keep_k = min(keep_k, probs.size(-1))
    _, pos_ids = torch.topk(probs[1:], k=keep_k, dim=-1)

    # 3. Valid candidates are the tokens outside the coordinate id range.
    is_valid = (pos_ids < coord_start_token_id) | (pos_ids > coord_end_token_id)
    if not is_valid.any(dim=-1).all():
        return None

    # 4. topk output is sorted by descending probability, so the first True
    #    per row marks the most probable valid token.
    first_valid_idx = is_valid.long().argmax(dim=-1, keepdim=True)  # [L-1, 1]
    final_text_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)  # [L-1]

    start_t = torch.tensor([ref_start_token_id], dtype=final_text_ids.dtype, device=device)
    return torch.cat([start_t, final_text_ids])
def handle_pattern(x0, token_ids: Dict[str, int], generation_mode: str = 'hybrid'):
    """
    Classify a decoded token frame and select the tokens to emit.

    Args:
        x0: Token IDs of one frame (nominally length 6), as a tensor or any
            sequence of ints.
        token_ids: Dictionary containing all token IDs
        generation_mode: ``'fast'`` keeps a malformed box as ``coord_box``;
            any other value reports it as ``error_box`` and requests a
            switch to AR decoding.

    Returns:
        Dict with keys ``type`` (one of ``im_end``, ``empty_box``,
        ``coord_box``, ``point_box``, ``error_box``, ``ref_object``),
        ``tokens`` (list of token ids), ``need_switch_to_ar`` and
        ``is_terminal``.

    Raises:
        IndexError: If ``x0`` is empty.
    """
    null_token_id = token_ids['null_token_id']
    im_end_token_id = token_ids['im_end_token_id']
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    none_token_id = token_ids['none_token_id']
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    ref_end_token_id = token_ids['ref_end_token_id']

    def result(kind, tokens, need_switch_to_ar=False, is_terminal=False):
        # Single place that fixes the shape of the returned record.
        return {
            "type": kind,
            "tokens": tokens,
            "need_switch_to_ar": need_switch_to_ar,
            "is_terminal": is_terminal,
        }

    # Accept tensors as well as plain sequences.
    x0 = x0.tolist() if hasattr(x0, 'tolist') else list(x0)
    head = x0[0]

    # A leading null token is reported exactly like an explicit im_end.
    if head in (null_token_id, im_end_token_id):
        return result("im_end", [im_end_token_id], is_terminal=True)

    if x0[:2] == [box_start_token_id, none_token_id]:
        return result(
            "empty_box", [box_start_token_id, none_token_id, box_end_token_id]
        )

    if head == box_start_token_id:
        # coord_ix ends up as the index of the first non-coordinate token
        # after box_start (at most 5).
        coord_ix = 1
        for coord in x0[1:5]:
            if not (coord_start_token_id <= coord <= coord_end_token_id):
                break
            coord_ix += 1

        # Standard 4-coordinate bbox. The length guards keep a truncated
        # frame from raising IndexError; it is handled as malformed below.
        if coord_ix == 5 and len(x0) > 5 and x0[5] == box_end_token_id:
            return result("coord_box", x0)

        # Two-coordinate pointing:
        # Convention: the first two coordinates are valid coord tokens, the
        # next token is box_end. Remaining positions (if any) are not part
        # of the pattern; truncate at box_end.
        if coord_ix == 3 and len(x0) > 3 and x0[3] == box_end_token_id:
            return result("point_box", x0[:4])

        if generation_mode == 'fast':
            # fast mode: treat as coord_box, stay in MTP
            return result("coord_box", x0)
        # hybrid mode: error_box, switch to AR
        return result("error_box", x0[:coord_ix], need_switch_to_ar=True)

    # Reference text: cut at the first null token, then collapse a doubled
    # trailing ref_end into one.
    if null_token_id in x0:
        x0 = x0[:x0.index(null_token_id)]
    if len(x0) >= 2 and x0[-1] == x0[-2] == ref_end_token_id:
        x0 = x0[:-1]
    return result("ref_object", x0)