# TRASER / resampler_utils / token_arrangement.py
# Provenance: uploaded via huggingface_hub by user UWGZQ (commit f72dd03, verified).
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple
import math
def rearrange_token(
    model,
    input_ids: torch.LongTensor,
    attention_mask: torch.LongTensor,
    pixel_values: Optional[torch.FloatTensor],
    image_grid_thw: Optional[torch.LongTensor],
    pixel_values_videos: Optional[torch.FloatTensor],
    video_grid_thw: Optional[torch.LongTensor],
    second_per_grid_ts: Optional[torch.Tensor],
    obj_token_indices_per_sample: List[List[torch.Tensor]],
    obj_traj_start_id: Optional[int] = None,
    obj_traj_end_id: Optional[int] = None,
    text_token_ids_per_sample: Optional[List[List[torch.Tensor]]] = None,
    timestamp_token_ids_per_batch=None,  # per-sample list of 1-D id tensors, one per temporal window
    grids_per_temporal_window_per_batch=None,  # per-sample int-like: temporal grids per window
    labels: Optional[torch.LongTensor] = None,
    IGNORE_ID: int = -100,
    use_resampler: bool = True,
    use_second_resampler: bool = True,
    add_timestamp_token: bool = True,
):
    """Rebuild each sequence, replacing the raw video-token span with per-object blocks.

    For every sample the contiguous run of video tokens (optionally bracketed by
    vision-start / vision-end tokens) is removed and replaced by one block per
    object, each laid out as:

        [obj_traj_start] + object text tokens + [vision_start]
        + (optional) R2 global-summary latents from the *second* resampler
        + per non-empty temporal window:
              (optional) timestamp text tokens + R latents from the main resampler
        + [vision_end] + [obj_traj_end]

    Latent slots are first written as ``video_token_id`` placeholders into the new
    ``input_ids``; after moving back to device, the resampled visual features are
    pasted over those placeholder positions in ``inputs_embeds``. All inserted
    tokens get text-style 1-D positions replicated across the 3 RoPE dims.

    Returns:
        (new_inputs_embeds, new_position_ids, new_attention_mask, rope_deltas,
         cache_position, new_input_ids, new_labels) — new_labels is None when
        ``labels`` is None. Note ``pixel_values``/``image_grid_thw``/
        ``second_per_grid_ts`` are accepted but unused in this function.
    """
    dev = input_ids.device
    B, L = input_ids.shape
    cpu = torch.device("cpu")
    # Text blocks are mandatory: one text-token tensor per object per sample.
    assert text_token_ids_per_sample is not None and len(text_token_ids_per_sample) == B, \
        "mode3_traj_and_text requires text_token_ids_per_sample with length B."
    if add_timestamp_token:
        assert timestamp_token_ids_per_batch is not None and len(timestamp_token_ids_per_batch) == B, \
            "add_timestamp_token=True requires timestamp_token_ids_per_batch with length B."
        assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
            "add_timestamp_token=True requires grids_per_temporal_window_per_batch with length B."
    else:
        # Window sizing is needed even without timestamp tokens (drives latent counts).
        assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
            "grids_per_temporal_window_per_batch is required."
    tok_embed = model.get_input_embeddings()
    vt_id = int(model.config.video_token_id)
    vs_id = getattr(model.config, "vision_start_token_id", None)
    ve_id = getattr(model.config, "vision_end_token_id", None)
    # NOTE(review): hard-coded pad id (151643 looks like Qwen's <|endoftext|>) —
    # confirm against the actual tokenizer / model.config.pad_token_id.
    pad_id = 151643
    # ---- (0+) temporal window meta ----
    assert video_grid_thw is not None, "video_grid_thw is required for temporal windowing"
    assert video_grid_thw.shape[0] == B and video_grid_thw.shape[1] == 3, \
        f"video_grid_thw should be ({B},3), got {video_grid_thw.shape}"
    grid_area_batch: List[int] = []
    temporal_window_size_batch = grids_per_temporal_window_per_batch
    # ---- (0) Compute visual features (with grad) ----
    video_embeds = None
    if pixel_values_videos is not None:
        _vid = model.model.get_video_features(
            pixel_values_videos.type(model.model.visual.dtype), video_grid_thw
        )
        # get_video_features may return a list/tuple of per-video chunks; flatten
        # to a single (total_video_tokens, D) tensor.
        video_embeds = torch.cat(_vid, dim=0) if isinstance(_vid, (list, tuple)) else _vid
        del pixel_values_videos, _vid
    # ---- (0.1) Resamplers ----
    resampler = None
    resampler_num_latents = None
    second_resampler = None
    second_resampler_num_latents = None
    if use_resampler:
        if not hasattr(model, "perceiver_resampler"):
            raise RuntimeError("use_resampler=True, but model.perceiver_resampler not found.")
        resampler = model.perceiver_resampler
        resampler_num_latents = int(resampler.n_latents)
    if use_second_resampler:
        if not hasattr(model, "second_perceiver_resampler"):
            raise RuntimeError("use_second_resampler=True, but model.second_perceiver_resampler not found.")
        second_resampler = model.second_perceiver_resampler
        second_resampler_num_latents = int(second_resampler.n_latents)
    # ---- (1) Position ids preparation ----
    position_ids_full = None  # NOTE(review): assigned but never used below.
    # ---- (2) Move to CPU for sequence planning ----
    # All index bookkeeping is done on CPU; tensors return to `dev` in step (5).
    attn_cpu = attention_mask.to(cpu, dtype=torch.bool)
    ids_cpu = input_ids.to(cpu)
    pid_cpu = None  # NOTE(review): assigned but never used below.
    lbls_cpu = labels.to(cpu) if labels is not None else None
    eff_lens: List[int] = []          # effective (unpadded) length per sample
    vid_idx_list: List[torch.Tensor] = []  # positions of video tokens per sample
    for b in range(B):
        video_grid_thw_b = video_grid_thw[b]
        # Spatial tokens per temporal grid; //4 presumably accounts for the 2x2
        # spatial merge of the vision tower — TODO confirm merge factor.
        grid_area = (int(video_grid_thw_b[1].item()) * int(video_grid_thw_b[2].item())) // 4
        grid_area_batch.append(int(grid_area))
        # Effective length = index of last attended token + 1 (assumes left-aligned mask).
        nz = torch.nonzero(attn_cpu[b], as_tuple=False).flatten()
        L_eff = int(nz[-1].item()) + 1 if nz.numel() > 0 else 0
        eff_lens.append(L_eff)
        if L_eff > 0:
            ids_b_eff = ids_cpu[b, :L_eff]
            vid_idx = torch.nonzero(ids_b_eff == vt_id, as_tuple=False).flatten()
            vid_idx_list.append(vid_idx)
        else:
            vid_idx_list.append(torch.empty(0, dtype=torch.long))
    # Per-sample offset into the flat video_embeds rows (videos are concatenated
    # in batch order by get_video_features).
    vid_counts = [int(v.numel()) for v in vid_idx_list]
    vid_offsets: List[int] = [0] * B
    running = 0
    for b in range(B):
        vid_offsets[b] = running
        running += vid_counts[b]
    # ---- (3) Length planning ----
    def _object_block_len(b: int, obj_i: int, sel_latent_len: int, rel_temporal_window_idx: torch.Tensor) -> int:
        """Token count of one object's block (markers + text + timestamps + latents)."""
        add = 0
        if obj_traj_start_id is not None:
            add += 1
        tlen = int(text_token_ids_per_sample[b][obj_i].numel())
        add += tlen
        if vs_id is not None:
            add += 1
        if add_timestamp_token and timestamp_token_ids_per_batch is not None:
            # One timestamp string per distinct window this object occupies;
            # out-of-range windows fall back to the last entry.
            locs = rel_temporal_window_idx.unique()
            for loc in locs:
                loc_i = int(loc.item())
                if loc_i < len(timestamp_token_ids_per_batch[b]):
                    add += int(timestamp_token_ids_per_batch[b][loc_i].numel())
                else:
                    add += int(timestamp_token_ids_per_batch[b][-1].numel())
        add += int(sel_latent_len)
        # VE
        if ve_id is not None:
            add += 1
        if obj_traj_end_id is not None:
            add += 1
        return add
    L_new_each: List[int] = []
    for b in range(B):
        L_eff = eff_lens[b]
        ids_b = ids_cpu[b, :L_eff]
        vid_idx = vid_idx_list[b]
        if L_eff == 0:
            L_new_each.append(0)
            continue
        if vid_idx.numel() == 0:
            # No video span: the sequence is carried over unchanged.
            L_new_each.append(L_eff)
            continue
        v_s = int(vid_idx[0].item())
        v_e = int(vid_idx[-1].item())
        # Widen the removed span to swallow adjacent <VS>/<VE> markers if present.
        has_vs = (vs_id is not None and v_s - 1 >= 0 and ids_b[v_s - 1].item() == vs_id)
        has_ve = (ve_id is not None and v_e + 1 < L_eff and ids_b[v_e + 1].item() == ve_id)
        if has_vs:
            v_s -= 1
        if has_ve:
            v_e += 1
        prefix_len = v_s
        suffix_len = L_eff - (v_e + 1)
        sel_lists = obj_token_indices_per_sample[b]
        Nv = int(vid_idx.numel())
        cur_total = 0
        for i, rel in enumerate(sel_lists):
            rel = rel.to(cpu, dtype=torch.long)
            sel_len = int(rel.numel())  # NOTE(review): dead assignment — overwritten below.
            tokens_per_window = int(grid_area_batch[b] * int(temporal_window_size_batch[b]))
            rel_temporal_window_idx = rel // tokens_per_window if (tokens_per_window > 0) else torch.zeros_like(rel)
            nonempty_windows = int(rel_temporal_window_idx.unique().numel())
            # Latent budget: R per occupied window, plus R2 global latents when
            # the second resampler is active.
            if use_second_resampler and second_resampler_num_latents is not None:
                sel_len = int(second_resampler_num_latents) + int(resampler_num_latents) * nonempty_windows
            else:
                sel_len = int(resampler_num_latents) * nonempty_windows
            cur_total += _object_block_len(b, i, sel_len, rel_temporal_window_idx)
        L_new_each.append(prefix_len + cur_total + suffix_len)
    Lmax = max(L_new_each) if len(L_new_each) > 0 else 0
    # ---- (4) Allocate new sequence tensors on CPU and fill per-sample ----
    new_input_ids_cpu = torch.full((B, Lmax), pad_id, dtype=torch.long, device=cpu)
    new_attention_mask_cpu = torch.zeros((B, Lmax), dtype=torch.bool, device=cpu)
    # 3 leading dims = the three RoPE axes (temporal/height/width style layout).
    new_position_ids_cpu = torch.zeros((3, B, Lmax), dtype=torch.int32, device=cpu)
    new_labels_cpu = None
    if labels is not None:
        new_labels_cpu = torch.full((B, Lmax), IGNORE_ID, dtype=torch.long, device=cpu)
    rows_for_video: List[torch.Tensor] = [torch.empty(0, dtype=torch.long) for _ in range(B)]
    # Work queues for the batched resampler passes in steps (5.1)/(5.2):
    # per-window latent groups for the main resampler ...
    batched_obj_rows: List[torch.Tensor] = []   # rows into video_embeds
    batched_obj_pos: List[torch.Tensor] = []    # destination positions in the new sequence
    batched_obj_bids: List[int] = []            # owning sample index
    batched_obj_lens: List[int] = []            # group length (for padding mask)
    # ... and per-object global groups for the second resampler.
    batched_second_rows: List[torch.Tensor] = []
    batched_second_pos: List[torch.Tensor] = []
    batched_second_bids: List[int] = []
    batched_second_oids: List[int] = []
    def _text_pos_block(start_scalar: int, length: int, dtype=torch.int32) -> torch.Tensor:
        """Create 1D-linear positions replicated across 3 RoPE dims."""
        if length <= 0:
            return torch.empty(3, 0, dtype=dtype, device=cpu)
        ar = torch.arange(start_scalar, start_scalar + length, device=cpu, dtype=dtype)
        return torch.stack([ar, ar, ar], dim=0)
    for b in range(B):
        L_eff = eff_lens[b]
        if L_eff == 0:
            continue
        ids_b = ids_cpu[b, :L_eff]
        msk_b = attn_cpu[b, :L_eff]
        labs_b = lbls_cpu[b, :L_eff] if lbls_cpu is not None else None
        vid_idx = vid_idx_list[b]
        dst = 0  # write cursor into the new sequence
        if vid_idx.numel() == 0:
            # No video span: straight copy.
            new_input_ids_cpu[b, :L_eff] = ids_b
            new_attention_mask_cpu[b, :L_eff] = msk_b
            if new_labels_cpu is not None and labs_b is not None:
                new_labels_cpu[b, :L_eff] = labs_b
            new_position_ids_cpu[:, b, :L_eff] = _text_pos_block(0, L_eff, dtype=torch.int32)
            continue
        # Recompute the (possibly marker-widened) video span, mirroring step (3).
        v_s = int(vid_idx[0].item())
        v_e = int(vid_idx[-1].item())
        has_vs = (vs_id is not None and v_s - 1 >= 0 and ids_b[v_s - 1].item() == vs_id)
        has_ve = (ve_id is not None and v_e + 1 < L_eff and ids_b[v_e + 1].item() == ve_id)
        if has_vs:
            v_s -= 1
        if has_ve:
            v_e += 1
        prefix_len = v_s
        suffix_len = L_eff - (v_e + 1)
        # Copy pre-video prefix verbatim.
        if prefix_len > 0:
            new_input_ids_cpu[b, dst:dst + prefix_len] = ids_b[:prefix_len]
            new_attention_mask_cpu[b, dst:dst + prefix_len] = msk_b[:prefix_len]
            if new_labels_cpu is not None and labs_b is not None:
                new_labels_cpu[b, dst:dst + prefix_len] = labs_b[:prefix_len]
            new_position_ids_cpu[:, b, dst:dst + prefix_len] = _text_pos_block(dst, prefix_len, dtype=torch.int32)
            dst += prefix_len
        Nv = int(vid_idx.numel())
        # NOTE(review): pos2rank is built but never read afterwards — dead code?
        pos2rank = torch.full((L_eff,), -1, dtype=torch.long, device=cpu)
        if Nv > 0:
            pos2rank[vid_idx] = torch.arange(Nv, dtype=torch.long, device=cpu)
        vid_offset = int(vid_offsets[b])
        sel_lists = obj_token_indices_per_sample[b]
        for i, rel in enumerate(sel_lists):
            rel = rel.to(cpu, dtype=torch.long)
            if rel.numel() > 0:
                # NOTE(review): in-place clamp — if `rel` was already a CPU long
                # tensor, .to() is a no-op and this mutates the caller's tensor.
                rel.clamp_(0, Nv - 1)
            g = vid_idx.index_select(0, rel) if (Nv > 0 and rel.numel() > 0) else torch.empty(0, dtype=torch.long, device=cpu)
            # (1) <obj_traj_start> (optional)
            if obj_traj_start_id is not None:
                new_input_ids_cpu[b, dst] = int(obj_traj_start_id)
                new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(dst, 1, dtype=torch.int32)
                if new_labels_cpu is not None:
                    new_labels_cpu[b, dst] = IGNORE_ID
                new_attention_mask_cpu[b, dst] = True
                dst += 1
            # (2) text tokens (required)
            txt_ids = text_token_ids_per_sample[b][i].to(cpu, dtype=torch.long)
            k = int(txt_ids.numel())
            if k > 0:
                new_input_ids_cpu[b, dst:dst + k] = txt_ids
                new_position_ids_cpu[:, b, dst:dst + k] = _text_pos_block(dst, k, dtype=torch.int32)
                if new_labels_cpu is not None:
                    new_labels_cpu[b, dst:dst + k] = IGNORE_ID
                new_attention_mask_cpu[b, dst:dst + k] = True
                dst += k
            # (3) <VS> (optional)
            if vs_id is not None:
                new_input_ids_cpu[b, dst] = int(vs_id)
                new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(dst, 1, dtype=torch.int32)
                if new_labels_cpu is not None:
                    new_labels_cpu[b, dst] = IGNORE_ID
                new_attention_mask_cpu[b, dst] = True
                dst += 1
            # (4) video tokens
            if g.numel() > 0:
                tokens_per_window = int(grid_area_batch[b] * int(temporal_window_size_batch[b]))
                rel_temporal_window_idx = rel // tokens_per_window if (tokens_per_window > 0) else torch.zeros_like(rel)
                W_eff = int(rel_temporal_window_idx.max().item()) + 1 if rel_temporal_window_idx.numel() > 0 else 0
                # Group this object's video rows by temporal window (empty
                # windows keep an empty placeholder so indices align with w).
                all_rows_list = []
                for w in range(W_eff):
                    m_w = (rel_temporal_window_idx == w)
                    if not torch.any(m_w):
                        all_rows_list.append(torch.empty(0, dtype=torch.long, device=cpu))
                        continue
                    rel_w = rel[m_w]
                    rows_w = rel_w + vid_offset  # relative -> flat video_embeds row
                    all_rows_list.append(rows_w)
                # second resampler: global object summary
                if use_second_resampler and second_resampler is not None:
                    rows_all = torch.cat([x for x in all_rows_list if x.numel() > 0], dim=0) if any(x.numel() > 0 for x in all_rows_list) \
                        else torch.empty(0, dtype=torch.long, device=cpu)
                    if rows_all.numel() > 0:
                        # Reserve R2 placeholder slots; features are pasted in (5.1).
                        R2 = int(second_resampler_num_latents)
                        new_input_ids_cpu[b, dst:dst + R2] = int(vt_id)
                        new_position_ids_cpu[:, b, dst:dst + R2] = _text_pos_block(dst, R2, dtype=torch.int32)
                        if new_labels_cpu is not None:
                            new_labels_cpu[b, dst:dst + R2] = IGNORE_ID
                        new_attention_mask_cpu[b, dst:dst + R2] = True
                        pos_idx2 = torch.arange(dst, dst + R2, dtype=torch.long, device=cpu)
                        batched_second_rows.append(rows_all)
                        batched_second_pos.append(pos_idx2)
                        batched_second_bids.append(b)
                        batched_second_oids.append(i)
                        dst += R2
                R = int(resampler_num_latents)
                for w in range(W_eff):
                    m_w = (rel_temporal_window_idx == w)
                    if not torch.any(m_w):
                        continue
                    # timestamp tokens (text-only; NOT injected into resampler)
                    if add_timestamp_token and (timestamp_token_ids_per_batch is not None):
                        loc = w
                        # Out-of-range windows reuse the last timestamp entry,
                        # matching _object_block_len's accounting.
                        if loc < len(timestamp_token_ids_per_batch[b]):
                            ts_ids = timestamp_token_ids_per_batch[b][loc].to(cpu, dtype=torch.long)
                        else:
                            ts_ids = timestamp_token_ids_per_batch[b][-1].to(cpu, dtype=torch.long)
                        kt = int(ts_ids.numel())
                        assert kt > 0, "Timestamp token ids should not be empty."
                        new_input_ids_cpu[b, dst:dst + kt] = ts_ids
                        new_position_ids_cpu[:, b, dst:dst + kt] = _text_pos_block(dst, kt, dtype=torch.int32)
                        if new_labels_cpu is not None:
                            new_labels_cpu[b, dst:dst + kt] = IGNORE_ID
                        new_attention_mask_cpu[b, dst:dst + kt] = True
                        dst += kt
                    # Reserve R placeholder slots for this window's latents (filled in 5.2).
                    new_input_ids_cpu[b, dst:dst + R] = int(vt_id)
                    new_position_ids_cpu[:, b, dst:dst + R] = _text_pos_block(dst, R, dtype=torch.int32)
                    if new_labels_cpu is not None:
                        new_labels_cpu[b, dst:dst + R] = IGNORE_ID
                    new_attention_mask_cpu[b, dst:dst + R] = True
                    rel_w = rel[m_w]
                    rows_w = rel_w + vid_offset
                    pos_idx = torch.arange(dst, dst + R, dtype=torch.long, device=cpu)
                    batched_obj_rows.append(rows_w)
                    batched_obj_pos.append(pos_idx)
                    batched_obj_bids.append(b)
                    batched_obj_lens.append(int(rows_w.numel()))
                    dst += R
            # (5) <VE> (optional)
            if ve_id is not None:
                new_input_ids_cpu[b, dst] = int(ve_id)
                new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(dst, 1, dtype=torch.int32)
                if new_labels_cpu is not None:
                    new_labels_cpu[b, dst] = IGNORE_ID
                new_attention_mask_cpu[b, dst] = True
                dst += 1
            # (6) <obj_traj_end> (optional)
            if obj_traj_end_id is not None:
                new_input_ids_cpu[b, dst] = int(obj_traj_end_id)
                new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(dst, 1, dtype=torch.int32)
                if new_labels_cpu is not None:
                    new_labels_cpu[b, dst] = IGNORE_ID
                new_attention_mask_cpu[b, dst] = True
                dst += 1
        # suffix
        if suffix_len > 0:
            src_lo = v_e + 1
            src_hi = L_eff
            seg = src_hi - src_lo
            new_input_ids_cpu[b, dst:dst + seg] = ids_b[src_lo:src_hi]
            new_attention_mask_cpu[b, dst:dst + seg] = msk_b[src_lo:src_hi]
            if new_labels_cpu is not None and labs_b is not None:
                new_labels_cpu[b, dst:dst + seg] = labs_b[src_lo:src_hi]
            new_position_ids_cpu[:, b, dst:dst + seg] = _text_pos_block(dst, seg, dtype=torch.int32)
            dst += seg
        # Sanity: the fill pass must land exactly on the planned length.
        assert dst == L_new_each[b], f"sample {b}: dst={dst}, L_new={L_new_each[b]}"
    # ---- (5) Move back to device, build inputs_embeds, and paste visual features ----
    new_input_ids = new_input_ids_cpu.to(dev, non_blocking=True)
    new_position_ids = new_position_ids_cpu.to(dev, non_blocking=True)
    new_attention_mask = new_attention_mask_cpu.to(dev, non_blocking=True)
    new_labels = None if new_labels_cpu is None else new_labels_cpu.to(dev, non_blocking=True)
    base = tok_embed(new_input_ids)
    new_inputs_embeds = base.clone()  # clone so scatter-writes don't alias the embed output
    # ---- (5.1) second resampler: object-level global summary ----
    if use_resampler and use_second_resampler and len(batched_second_rows) > 0:
        if video_embeds is None:
            raise RuntimeError("use_second_resampler=True but video_embeds is None.")
        dev_emb = video_embeds.device
        dtype_emb = video_embeds.dtype
        D = video_embeds.shape[-1]
        N_obj2 = len(batched_second_rows)
        # Gather each object's full token set, pad to a batch, run one resampler call.
        seqs2 = []
        lens2 = []
        for rows_all in batched_second_rows:
            if rows_all.numel() == 0:
                seqs2.append(torch.zeros(0, D, device=dev_emb, dtype=dtype_emb))
                lens2.append(0)
            else:
                seqs2.append(video_embeds.index_select(0, rows_all.to(dev_emb)))
                lens2.append(int(rows_all.numel()))
        x2 = torch.nn.utils.rnn.pad_sequence(seqs2, batch_first=True) if len(seqs2) > 0 else torch.zeros(0, 0, D, device=dev_emb, dtype=dtype_emb)
        L2_max = x2.size(1) if x2.numel() > 0 else 0
        lens2_t = torch.tensor(lens2, device=dev_emb, dtype=torch.long) if len(lens2) > 0 else torch.zeros(0, device=dev_emb, dtype=torch.long)
        ar2 = torch.arange(L2_max, device=dev_emb).unsqueeze(0) if L2_max > 0 else torch.zeros(1, 0, device=dev_emb, dtype=torch.long)
        mask2 = (ar2 < lens2_t.unsqueeze(1)) if L2_max > 0 else torch.zeros(0, 0, device=dev_emb, dtype=torch.bool)
        y2 = second_resampler(x2, attention_mask=mask2)
        y2 = y2.to(new_inputs_embeds.dtype)
        # Paste each object's R2 summary latents over its reserved placeholder slots.
        for j in range(N_obj2):
            b_cur = batched_second_bids[j]
            pos2 = batched_second_pos[j].to(dev)
            new_inputs_embeds[b_cur, pos2] = y2[j]
    # ---- (5.2) main resampler: temporal resampler----
    if use_resampler and len(batched_obj_rows) > 0:
        if video_embeds is None:
            raise RuntimeError("use_resampler=True but video_embeds is None.")
        dev_emb = video_embeds.device
        dtype_emb = video_embeds.dtype
        D = video_embeds.shape[-1]
        N_obj = len(batched_obj_rows)
        lens = torch.tensor(batched_obj_lens, device=dev_emb, dtype=torch.long)
        L_max = int(lens.max().item()) if lens.numel() > 0 else 0
        # One padded batch over all (object, window) groups — single resampler call.
        seqs = []
        for rows in batched_obj_rows:
            if rows.numel() == 0:
                seqs.append(torch.zeros(0, D, device=dev_emb, dtype=dtype_emb))
            else:
                seqs.append(video_embeds.index_select(0, rows.to(dev_emb)))
        x = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True) if len(seqs) > 0 else torch.zeros(0, 0, D, device=dev_emb, dtype=dtype_emb)
        ar = torch.arange(L_max, device=dev_emb).unsqueeze(0) if L_max > 0 else torch.zeros(1, 0, device=dev_emb, dtype=torch.long)
        mask = (ar < lens.unsqueeze(1)) if L_max > 0 else torch.zeros(0, 0, device=dev_emb, dtype=torch.bool)
        y = resampler(x, attention_mask=mask)
        y = y.to(new_inputs_embeds.dtype)
        # Group results by sample so each row gets a single scatter-write.
        per_b_indices: List[List[int]] = [[] for _ in range(B)]
        for i in range(N_obj):
            per_b_indices[batched_obj_bids[i]].append(i)
        for b in range(B):
            if not per_b_indices[b]:
                continue
            pos_list = []
            emb_list = []
            for i in per_b_indices[b]:
                pos_list.append(batched_obj_pos[i].to(dev))
                emb_list.append(y[i])
            pos_b = torch.cat(pos_list, dim=0)
            emb_b = torch.cat(emb_list, dim=0)
            new_inputs_embeds[b, pos_b] = emb_b
    # ---- (6) rope_deltas / cache_position ----
    # Max position across the 3 RoPE dims per sample, shape (B, 1); delta is the
    # usual mrope offset (max_pos + 1 - seq_len) used for generation caching.
    maxpos = new_position_ids.max(dim=0)[0].max(dim=1, keepdim=True)[0]
    rope_deltas = (maxpos + 1 - new_inputs_embeds.shape[1]).to(dtype=torch.long, device=dev)
    cache_position = torch.arange(new_inputs_embeds.shape[1], device=dev, dtype=torch.int32)
    return new_inputs_embeds, new_position_ids, new_attention_mask, rope_deltas, cache_position, new_input_ids, new_labels