UWGZQ committed on
Commit
f72dd03
·
verified ·
1 Parent(s): 611ab28

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. inference.py +0 -1
  2. resampler_utils/token_arrangement.py +111 -235
inference.py CHANGED
@@ -160,7 +160,6 @@ def run_single_video(model, processor, video_path, mask_path, out_dir, device, a
160
  text_token_ids_per_sample=text_token_ids_per_sample,
161
  timestamp_token_ids_per_batch=timestamp_token_ids_per_batch,
162
  grids_per_temporal_window_per_batch=grids_per_window_batch,
163
- use_resampler=True
164
  )
165
 
166
  gen_out = model.generate(
 
160
  text_token_ids_per_sample=text_token_ids_per_sample,
161
  timestamp_token_ids_per_batch=timestamp_token_ids_per_batch,
162
  grids_per_temporal_window_per_batch=grids_per_window_batch,
 
163
  )
164
 
165
  gen_out = model.generate(
resampler_utils/token_arrangement.py CHANGED
@@ -6,49 +6,31 @@ import math
6
 
7
  def rearrange_token(
8
  model,
9
- input_ids: torch.LongTensor, # [B, L]
10
- attention_mask: torch.LongTensor, # [B, L]
11
- pixel_values: Optional[torch.FloatTensor], # unused here (image path kept for API compatibility)
12
- image_grid_thw: Optional[torch.LongTensor], # unused here (image path kept for API compatibility)
13
- pixel_values_videos: Optional[torch.FloatTensor], # may be None
14
- video_grid_thw: Optional[torch.LongTensor], # may be None
15
- second_per_grid_ts: Optional[torch.Tensor], # may be None
16
-
17
- # Per-sample list of objects; each object is a 1D LongTensor of relative video-token indices (in the original video token stream)
18
  obj_token_indices_per_sample: List[List[torch.Tensor]],
19
 
20
- # Only mode3_traj_and_text is kept:
21
  obj_traj_start_id: Optional[int] = None,
22
  obj_traj_end_id: Optional[int] = None,
23
 
24
- # Required: List[sample][object] -> 1D LongTensor(ids)
25
  text_token_ids_per_sample: Optional[List[List[torch.Tensor]]] = None,
26
 
27
- timestamp_token_ids_per_batch=None, # List[sample][1D LongTensor(ids)]
28
- grids_per_temporal_window_per_batch=None, # List[sample] number of grids per temporal window
29
 
30
  labels: Optional[torch.LongTensor] = None,
31
  IGNORE_ID: int = -100,
32
 
33
- use_resampler: bool = True, # True → per-object resampling + linear (1D) positions
34
  use_second_resampler: bool = True,
35
- add_timestamp_token: bool = True, # whether to add timestamp token for each object window
36
  ):
37
- """
38
- Fixed simplifications:
39
- - insert_where: only "in_order" (no argument kept)
40
- - insertion_mode: only "mode3_traj_and_text"
41
- - perceiver_injection: only "visuals" (no time tokens injected into resampler)
42
-
43
- Returns:
44
- new_inputs_embeds: [B, Lmax, D]
45
- new_position_ids: [3, B, Lmax] (int32)
46
- new_attention_mask: [B, Lmax] (bool)
47
- rope_deltas: [B, 1] (long)
48
- cache_position: [Lmax] (int32)
49
- new_input_ids: [B, Lmax] (long)
50
- new_labels: [B, Lmax] or None (long)
51
- """
52
  dev = input_ids.device
53
  B, L = input_ids.shape
54
  cpu = torch.device("cpu")
@@ -62,7 +44,6 @@ def rearrange_token(
62
  assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
63
  "add_timestamp_token=True requires grids_per_temporal_window_per_batch with length B."
64
  else:
65
- # still needed for window indexing if use_resampler path uses temporal windows
66
  assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
67
  "grids_per_temporal_window_per_batch is required."
68
 
@@ -70,14 +51,14 @@ def rearrange_token(
70
  vt_id = int(model.config.video_token_id)
71
  vs_id = getattr(model.config, "vision_start_token_id", None)
72
  ve_id = getattr(model.config, "vision_end_token_id", None)
73
- pad_id = 151643 # align with original implementation
74
 
75
  # ---- (0+) temporal window meta ----
76
  assert video_grid_thw is not None, "video_grid_thw is required for temporal windowing"
77
  assert video_grid_thw.shape[0] == B and video_grid_thw.shape[1] == 3, \
78
  f"video_grid_thw should be ({B},3), got {video_grid_thw.shape}"
79
 
80
- grid_area_batch: List[int] = [] # per-sample spatial token count (H*W/4)
81
  temporal_window_size_batch = grids_per_temporal_window_per_batch
82
 
83
  # ---- (0) Compute visual features (with grad) ----
@@ -86,7 +67,7 @@ def rearrange_token(
86
  _vid = model.model.get_video_features(
87
  pixel_values_videos.type(model.model.visual.dtype), video_grid_thw
88
  )
89
- video_embeds = torch.cat(_vid, dim=0) if isinstance(_vid, (list, tuple)) else _vid # [N_vid, D]
90
  del pixel_values_videos, _vid
91
 
92
  # ---- (0.1) Resamplers ----
@@ -106,30 +87,18 @@ def rearrange_token(
106
  second_resampler_num_latents = int(second_resampler.n_latents)
107
 
108
  # ---- (1) Position ids preparation ----
109
- need_3d_rope = (not use_resampler)
110
- if need_3d_rope:
111
- with torch.no_grad():
112
- position_ids_full, _ = model.model.get_rope_index(
113
- input_ids=input_ids,
114
- image_grid_thw=image_grid_thw,
115
- video_grid_thw=video_grid_thw,
116
- second_per_grid_ts=second_per_grid_ts,
117
- attention_mask=attention_mask,
118
- ).to(cpu) # (3, B, L)
119
- else:
120
- position_ids_full = None
121
 
122
  # ---- (2) Move to CPU for sequence planning ----
123
  attn_cpu = attention_mask.to(cpu, dtype=torch.bool)
124
  ids_cpu = input_ids.to(cpu)
125
- pid_cpu = position_ids_full.to(cpu, dtype=torch.int32) if need_3d_rope else None
126
  lbls_cpu = labels.to(cpu) if labels is not None else None
127
 
128
  eff_lens: List[int] = []
129
  vid_idx_list: List[torch.Tensor] = []
130
  for b in range(B):
131
  video_grid_thw_b = video_grid_thw[b]
132
- # H*W/4 as integer
133
  grid_area = (int(video_grid_thw_b[1].item()) * int(video_grid_thw_b[2].item())) // 4
134
  grid_area_batch.append(int(grid_area))
135
 
@@ -144,7 +113,6 @@ def rearrange_token(
144
  else:
145
  vid_idx_list.append(torch.empty(0, dtype=torch.long))
146
 
147
- # ---- Global offsets into concatenated video_embeds for each sample ----
148
  vid_counts = [int(v.numel()) for v in vid_idx_list]
149
  vid_offsets: List[int] = [0] * B
150
  running = 0
@@ -154,26 +122,17 @@ def rearrange_token(
154
 
155
  # ---- (3) Length planning ----
156
  def _object_block_len(b: int, obj_i: int, sel_latent_len: int, rel_temporal_window_idx: torch.Tensor) -> int:
157
- """
158
- mode3_traj_and_text block length:
159
- [<traj_start>?] + [text] + [<VS>?] + [<ts>* + <vt_latents>*] + [<VE>?] + [<traj_end>?]
160
- where <ts>* and <vt_latents>* repeat per non-empty temporal window (resampler path),
161
- or raw selected video tokens (non-resampler path).
162
- """
163
  add = 0
164
 
165
  if obj_traj_start_id is not None:
166
  add += 1
167
 
168
- # text
169
  tlen = int(text_token_ids_per_sample[b][obj_i].numel())
170
  add += tlen
171
 
172
- # VS
173
  if vs_id is not None:
174
  add += 1
175
 
176
- # timestamps per unique window (if enabled)
177
  if add_timestamp_token and timestamp_token_ids_per_batch is not None:
178
  locs = rel_temporal_window_idx.unique()
179
  for loc in locs:
@@ -183,7 +142,6 @@ def rearrange_token(
183
  else:
184
  add += int(timestamp_token_ids_per_batch[b][-1].numel())
185
 
186
- # visual placeholder length (either resampled latents or raw selected tokens)
187
  add += int(sel_latent_len)
188
 
189
  # VE
@@ -230,19 +188,14 @@ def rearrange_token(
230
  rel = rel.to(cpu, dtype=torch.long)
231
  sel_len = int(rel.numel())
232
 
233
- if use_resampler:
234
- tokens_per_window = int(grid_area_batch[b] * int(temporal_window_size_batch[b]))
235
- rel_temporal_window_idx = rel // tokens_per_window if (tokens_per_window > 0) else torch.zeros_like(rel)
236
- nonempty_windows = int(rel_temporal_window_idx.unique().numel())
237
 
238
- if use_second_resampler and second_resampler_num_latents is not None:
239
- sel_len = int(second_resampler_num_latents) + int(resampler_num_latents) * nonempty_windows
240
- else:
241
- sel_len = int(resampler_num_latents) * nonempty_windows
242
  else:
243
- # Non-resampler: keep raw selected video tokens count
244
- tokens_per_window = int(grid_area_batch[b] * int(temporal_window_size_batch[b]))
245
- rel_temporal_window_idx = rel // tokens_per_window if (tokens_per_window > 0) else torch.zeros_like(rel)
246
 
247
  cur_total += _object_block_len(b, i, sel_len, rel_temporal_window_idx)
248
 
@@ -260,10 +213,10 @@ def rearrange_token(
260
 
261
  rows_for_video: List[torch.Tensor] = [torch.empty(0, dtype=torch.long) for _ in range(B)]
262
 
263
- batched_obj_rows: List[torch.Tensor] = [] # each: rows into video_embeds (visual-only)
264
- batched_obj_pos: List[torch.Tensor] = [] # each: destination positions [R]
265
  batched_obj_bids: List[int] = []
266
- batched_obj_lens: List[int] = [] # visual token lengths per (object-window)
267
 
268
  batched_second_rows: List[torch.Tensor] = []
269
  batched_second_pos: List[torch.Tensor] = []
@@ -289,16 +242,12 @@ def rearrange_token(
289
 
290
  dst = 0
291
 
292
- # No video tokens: copy through
293
  if vid_idx.numel() == 0:
294
  new_input_ids_cpu[b, :L_eff] = ids_b
295
  new_attention_mask_cpu[b, :L_eff] = msk_b
296
  if new_labels_cpu is not None and labs_b is not None:
297
  new_labels_cpu[b, :L_eff] = labs_b
298
- if need_3d_rope:
299
- new_position_ids_cpu[:, b, :L_eff] = pid_cpu[:, b, :L_eff]
300
- else:
301
- new_position_ids_cpu[:, b, :L_eff] = _text_pos_block(0, L_eff, dtype=torch.int32)
302
  continue
303
 
304
  v_s = int(vid_idx[0].item())
@@ -313,34 +262,14 @@ def rearrange_token(
313
  prefix_len = v_s
314
  suffix_len = L_eff - (v_e + 1)
315
 
316
- if need_3d_rope:
317
- pid_b = pid_cpu[:, b, :L_eff]
318
- pos_scalar = pid_b.max(dim=0).values
319
- first_video_scalar = int(pos_scalar[v_s + (1 if has_vs else 0)].item())
320
- last_video_scalar = int(pos_scalar[v_e - (1 if has_ve else 0)].item())
321
- vs_scalar = int(pos_scalar[v_s].item()) if has_vs else None
322
-
323
- min_video_scalar_base = int(first_video_scalar)
324
- max_video_scalar_base = int(last_video_scalar)
325
-
326
- # prefix
327
  if prefix_len > 0:
328
  new_input_ids_cpu[b, dst:dst + prefix_len] = ids_b[:prefix_len]
329
  new_attention_mask_cpu[b, dst:dst + prefix_len] = msk_b[:prefix_len]
330
  if new_labels_cpu is not None and labs_b is not None:
331
  new_labels_cpu[b, dst:dst + prefix_len] = labs_b[:prefix_len]
332
- if need_3d_rope:
333
- new_position_ids_cpu[:, b, dst:dst + prefix_len] = pid_b[:, :prefix_len]
334
- else:
335
- new_position_ids_cpu[:, b, dst:dst + prefix_len] = _text_pos_block(dst, prefix_len, dtype=torch.int32)
336
  dst += prefix_len
337
 
338
- # in_order only:
339
- if need_3d_rope:
340
- cursor = int(vs_scalar) if has_vs else int(first_video_scalar)
341
- else:
342
- cursor = dst
343
-
344
  Nv = int(vid_idx.numel())
345
  pos2rank = torch.full((L_eff,), -1, dtype=torch.long, device=cpu)
346
  if Nv > 0:
@@ -359,170 +288,128 @@ def rearrange_token(
359
  # (1) <obj_traj_start> (optional)
360
  if obj_traj_start_id is not None:
361
  new_input_ids_cpu[b, dst] = int(obj_traj_start_id)
362
- new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(cursor if need_3d_rope else dst, 1, dtype=torch.int32)
363
  if new_labels_cpu is not None:
364
  new_labels_cpu[b, dst] = IGNORE_ID
365
  new_attention_mask_cpu[b, dst] = True
366
  dst += 1
367
- if need_3d_rope:
368
- cursor += 1
369
 
370
  # (2) text tokens (required)
371
  txt_ids = text_token_ids_per_sample[b][i].to(cpu, dtype=torch.long)
372
  k = int(txt_ids.numel())
373
  if k > 0:
374
  new_input_ids_cpu[b, dst:dst + k] = txt_ids
375
- new_position_ids_cpu[:, b, dst:dst + k] = _text_pos_block(cursor if need_3d_rope else dst, k, dtype=torch.int32)
376
  if new_labels_cpu is not None:
377
  new_labels_cpu[b, dst:dst + k] = IGNORE_ID
378
  new_attention_mask_cpu[b, dst:dst + k] = True
379
  dst += k
380
- if need_3d_rope:
381
- cursor += k
382
 
383
  # (3) <VS> (optional)
384
  if vs_id is not None:
385
  new_input_ids_cpu[b, dst] = int(vs_id)
386
- new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(cursor if need_3d_rope else dst, 1, dtype=torch.int32)
387
  if new_labels_cpu is not None:
388
  new_labels_cpu[b, dst] = IGNORE_ID
389
  new_attention_mask_cpu[b, dst] = True
390
  dst += 1
391
- if need_3d_rope:
392
- cursor += 1
393
 
394
  # (4) video tokens
395
  if g.numel() > 0:
396
- if use_resampler:
397
- tokens_per_window = int(grid_area_batch[b] * int(temporal_window_size_batch[b]))
398
- rel_temporal_window_idx = rel // tokens_per_window if (tokens_per_window > 0) else torch.zeros_like(rel)
399
-
400
- # Loop only over windows that actually appear in rel (robust)
401
- W_eff = int(rel_temporal_window_idx.max().item()) + 1 if rel_temporal_window_idx.numel() > 0 else 0
402
-
403
- all_rows_list = []
404
- for w in range(W_eff):
405
- m_w = (rel_temporal_window_idx == w)
406
- if not torch.any(m_w):
407
- all_rows_list.append(torch.empty(0, dtype=torch.long, device=cpu))
408
- continue
409
- rel_w = rel[m_w]
410
- rows_w = rel_w + vid_offset
411
- all_rows_list.append(rows_w)
412
-
413
- # second resampler: global object summary
414
- if use_second_resampler and second_resampler is not None:
415
- rows_all = torch.cat([x for x in all_rows_list if x.numel() > 0], dim=0) if any(x.numel() > 0 for x in all_rows_list) \
416
- else torch.empty(0, dtype=torch.long, device=cpu)
417
-
418
- if rows_all.numel() > 0:
419
- R2 = int(second_resampler_num_latents)
420
- new_input_ids_cpu[b, dst:dst + R2] = int(vt_id)
421
- new_position_ids_cpu[:, b, dst:dst + R2] = _text_pos_block(cursor if need_3d_rope else dst, R2, dtype=torch.int32)
422
- if new_labels_cpu is not None:
423
- new_labels_cpu[b, dst:dst + R2] = IGNORE_ID
424
- new_attention_mask_cpu[b, dst:dst + R2] = True
425
-
426
- pos_idx2 = torch.arange(dst, dst + R2, dtype=torch.long, device=cpu)
427
- batched_second_rows.append(rows_all)
428
- batched_second_pos.append(pos_idx2)
429
- batched_second_bids.append(b)
430
- batched_second_oids.append(i)
431
-
432
- dst += R2
433
- if need_3d_rope:
434
- cursor += R2
435
-
436
- R = int(resampler_num_latents)
437
-
438
- for w in range(W_eff):
439
- m_w = (rel_temporal_window_idx == w)
440
- if not torch.any(m_w):
441
- continue
442
-
443
- # timestamp tokens (text-only; NOT injected into resampler)
444
- if add_timestamp_token and (timestamp_token_ids_per_batch is not None):
445
- loc = w
446
- if loc < len(timestamp_token_ids_per_batch[b]):
447
- ts_ids = timestamp_token_ids_per_batch[b][loc].to(cpu, dtype=torch.long)
448
- else:
449
- ts_ids = timestamp_token_ids_per_batch[b][-1].to(cpu, dtype=torch.long)
450
- kt = int(ts_ids.numel())
451
- assert kt > 0, "Timestamp token ids should not be empty."
452
-
453
- new_input_ids_cpu[b, dst:dst + kt] = ts_ids
454
- new_position_ids_cpu[:, b, dst:dst + kt] = _text_pos_block(cursor if need_3d_rope else dst, kt, dtype=torch.int32)
455
- if new_labels_cpu is not None:
456
- new_labels_cpu[b, dst:dst + kt] = IGNORE_ID
457
- new_attention_mask_cpu[b, dst:dst + kt] = True
458
- dst += kt
459
- if need_3d_rope:
460
- cursor += kt
461
-
462
- # reserve R vt slots for resampled latents
463
- new_input_ids_cpu[b, dst:dst + R] = int(vt_id)
464
- new_position_ids_cpu[:, b, dst:dst + R] = _text_pos_block(cursor if need_3d_rope else dst, R, dtype=torch.int32)
465
- if new_labels_cpu is not None:
466
- new_labels_cpu[b, dst:dst + R] = IGNORE_ID
467
- new_attention_mask_cpu[b, dst:dst + R] = True
468
-
469
- rel_w = rel[m_w]
470
- rows_w = rel_w + vid_offset
471
- pos_idx = torch.arange(dst, dst + R, dtype=torch.long, device=cpu)
472
 
473
- batched_obj_rows.append(rows_w)
474
- batched_obj_pos.append(pos_idx)
475
- batched_obj_bids.append(b)
476
- batched_obj_lens.append(int(rows_w.numel())) # visuals-only
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
- dst += R
479
- if need_3d_rope:
480
- cursor += R
481
 
482
- else:
483
- # Non-resampler: 3D RoPE positions for selected raw video tokens
484
- assert need_3d_rope, "Non-resampler path requires 3D RoPE positions."
485
- pid_vid = pid_b.index_select(1, g) # (3, Lv_sel)
486
-
487
- # in_order only: shift selected pid by delta
488
- delta = int(cursor - min_video_scalar_base)
489
- if delta != 0:
490
- pid_vid = pid_vid + delta
491
- cursor = max_video_scalar_base + delta + 1
492
-
493
- Lv_sel = int(g.numel())
494
- new_input_ids_cpu[b, dst:dst + Lv_sel] = int(vt_id)
495
- new_position_ids_cpu[:, b, dst:dst + Lv_sel] = pid_vid
496
  if new_labels_cpu is not None:
497
- new_labels_cpu[b, dst:dst + Lv_sel] = IGNORE_ID
498
- new_attention_mask_cpu[b, dst:dst + Lv_sel] = True
 
 
 
 
499
 
500
- ranks = pos2rank.index_select(0, g)
501
- rows = ranks + vid_offset
502
- rows_for_video[b] = torch.cat([rows_for_video[b], rows], dim=0)
503
- dst += Lv_sel
504
 
 
505
  # (5) <VE> (optional)
506
  if ve_id is not None:
507
  new_input_ids_cpu[b, dst] = int(ve_id)
508
- new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(cursor if need_3d_rope else dst, 1, dtype=torch.int32)
509
  if new_labels_cpu is not None:
510
  new_labels_cpu[b, dst] = IGNORE_ID
511
  new_attention_mask_cpu[b, dst] = True
512
  dst += 1
513
- if need_3d_rope:
514
- cursor += 1
515
 
516
  # (6) <obj_traj_end> (optional)
517
  if obj_traj_end_id is not None:
518
  new_input_ids_cpu[b, dst] = int(obj_traj_end_id)
519
- new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(cursor if need_3d_rope else dst, 1, dtype=torch.int32)
520
  if new_labels_cpu is not None:
521
  new_labels_cpu[b, dst] = IGNORE_ID
522
  new_attention_mask_cpu[b, dst] = True
523
  dst += 1
524
- if need_3d_rope:
525
- cursor += 1
526
 
527
  # suffix
528
  if suffix_len > 0:
@@ -533,7 +420,7 @@ def rearrange_token(
533
  new_attention_mask_cpu[b, dst:dst + seg] = msk_b[src_lo:src_hi]
534
  if new_labels_cpu is not None and labs_b is not None:
535
  new_labels_cpu[b, dst:dst + seg] = labs_b[src_lo:src_hi]
536
- new_position_ids_cpu[:, b, dst:dst + seg] = _text_pos_block(dst, seg, dtype=torch.int32) if not need_3d_rope else _text_pos_block(cursor, seg, dtype=torch.int32)
537
  dst += seg
538
 
539
  assert dst == L_new_each[b], f"sample {b}: dst={dst}, L_new={L_new_each[b]}"
@@ -547,17 +434,6 @@ def rearrange_token(
547
  base = tok_embed(new_input_ids)
548
  new_inputs_embeds = base.clone()
549
 
550
- # Non-resampler: copy raw video features at vt positions
551
- if (video_embeds is not None) and (not use_resampler) and any(r.numel() > 0 for r in rows_for_video):
552
- vemb = video_embeds.to(dev, dtype=new_inputs_embeds.dtype, non_blocking=True)
553
- for b in range(B):
554
- rows = rows_for_video[b]
555
- if rows.numel() == 0:
556
- continue
557
- vt_pos = torch.nonzero(new_input_ids[b] == vt_id, as_tuple=False).flatten()
558
- assert vt_pos.numel() == rows.numel(), f"video rows mismatch for sample {b}"
559
- new_inputs_embeds[b].index_copy_(0, vt_pos.to(dev), vemb.index_select(0, rows.to(dev)))
560
-
561
  # ---- (5.1) second resampler: object-level global summary ----
562
  if use_resampler and use_second_resampler and len(batched_second_rows) > 0:
563
  if video_embeds is None:
@@ -582,7 +458,7 @@ def rearrange_token(
582
  ar2 = torch.arange(L2_max, device=dev_emb).unsqueeze(0) if L2_max > 0 else torch.zeros(1, 0, device=dev_emb, dtype=torch.long)
583
  mask2 = (ar2 < lens2_t.unsqueeze(1)) if L2_max > 0 else torch.zeros(0, 0, device=dev_emb, dtype=torch.bool)
584
 
585
- y2 = second_resampler(x2, attention_mask=mask2) # [N_obj2, R2, D]
586
  y2 = y2.to(new_inputs_embeds.dtype)
587
 
588
  for j in range(N_obj2):
@@ -590,7 +466,7 @@ def rearrange_token(
590
  pos2 = batched_second_pos[j].to(dev)
591
  new_inputs_embeds[b_cur, pos2] = y2[j]
592
 
593
- # ---- (5.2) main resampler: visuals-only ----
594
  if use_resampler and len(batched_obj_rows) > 0:
595
  if video_embeds is None:
596
  raise RuntimeError("use_resampler=True but video_embeds is None.")
@@ -599,7 +475,7 @@ def rearrange_token(
599
  D = video_embeds.shape[-1]
600
 
601
  N_obj = len(batched_obj_rows)
602
- lens = torch.tensor(batched_obj_lens, device=dev_emb, dtype=torch.long) # [N_obj]
603
  L_max = int(lens.max().item()) if lens.numel() > 0 else 0
604
 
605
  seqs = []
@@ -607,13 +483,13 @@ def rearrange_token(
607
  if rows.numel() == 0:
608
  seqs.append(torch.zeros(0, D, device=dev_emb, dtype=dtype_emb))
609
  else:
610
- seqs.append(video_embeds.index_select(0, rows.to(dev_emb))) # [Lv_sel, D]
611
  x = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True) if len(seqs) > 0 else torch.zeros(0, 0, D, device=dev_emb, dtype=dtype_emb)
612
 
613
  ar = torch.arange(L_max, device=dev_emb).unsqueeze(0) if L_max > 0 else torch.zeros(1, 0, device=dev_emb, dtype=torch.long)
614
  mask = (ar < lens.unsqueeze(1)) if L_max > 0 else torch.zeros(0, 0, device=dev_emb, dtype=torch.bool)
615
 
616
- y = resampler(x, attention_mask=mask) # [N_obj, R, D]
617
  y = y.to(new_inputs_embeds.dtype)
618
 
619
  per_b_indices: List[List[int]] = [[] for _ in range(B)]
@@ -633,7 +509,7 @@ def rearrange_token(
633
  new_inputs_embeds[b, pos_b] = emb_b
634
 
635
  # ---- (6) rope_deltas / cache_position ----
636
- maxpos = new_position_ids.max(dim=0)[0].max(dim=1, keepdim=True)[0] # [B,1]
637
  rope_deltas = (maxpos + 1 - new_inputs_embeds.shape[1]).to(dtype=torch.long, device=dev)
638
  cache_position = torch.arange(new_inputs_embeds.shape[1], device=dev, dtype=torch.int32)
639
 
 
6
 
7
  def rearrange_token(
8
  model,
9
+ input_ids: torch.LongTensor,
10
+ attention_mask: torch.LongTensor,
11
+ pixel_values: Optional[torch.FloatTensor],
12
+ image_grid_thw: Optional[torch.LongTensor],
13
+ pixel_values_videos: Optional[torch.FloatTensor],
14
+ video_grid_thw: Optional[torch.LongTensor],
15
+ second_per_grid_ts: Optional[torch.Tensor],
16
+
 
17
  obj_token_indices_per_sample: List[List[torch.Tensor]],
18
 
 
19
  obj_traj_start_id: Optional[int] = None,
20
  obj_traj_end_id: Optional[int] = None,
21
 
 
22
  text_token_ids_per_sample: Optional[List[List[torch.Tensor]]] = None,
23
 
24
+ timestamp_token_ids_per_batch=None,
25
+ grids_per_temporal_window_per_batch=None,
26
 
27
  labels: Optional[torch.LongTensor] = None,
28
  IGNORE_ID: int = -100,
29
 
30
+ use_resampler: bool = True,
31
  use_second_resampler: bool = True,
32
+ add_timestamp_token: bool = True,
33
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  dev = input_ids.device
35
  B, L = input_ids.shape
36
  cpu = torch.device("cpu")
 
44
  assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
45
  "add_timestamp_token=True requires grids_per_temporal_window_per_batch with length B."
46
  else:
 
47
  assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
48
  "grids_per_temporal_window_per_batch is required."
49
 
 
51
  vt_id = int(model.config.video_token_id)
52
  vs_id = getattr(model.config, "vision_start_token_id", None)
53
  ve_id = getattr(model.config, "vision_end_token_id", None)
54
+ pad_id = 151643
55
 
56
  # ---- (0+) temporal window meta ----
57
  assert video_grid_thw is not None, "video_grid_thw is required for temporal windowing"
58
  assert video_grid_thw.shape[0] == B and video_grid_thw.shape[1] == 3, \
59
  f"video_grid_thw should be ({B},3), got {video_grid_thw.shape}"
60
 
61
+ grid_area_batch: List[int] = []
62
  temporal_window_size_batch = grids_per_temporal_window_per_batch
63
 
64
  # ---- (0) Compute visual features (with grad) ----
 
67
  _vid = model.model.get_video_features(
68
  pixel_values_videos.type(model.model.visual.dtype), video_grid_thw
69
  )
70
+ video_embeds = torch.cat(_vid, dim=0) if isinstance(_vid, (list, tuple)) else _vid
71
  del pixel_values_videos, _vid
72
 
73
  # ---- (0.1) Resamplers ----
 
87
  second_resampler_num_latents = int(second_resampler.n_latents)
88
 
89
  # ---- (1) Position ids preparation ----
90
+ position_ids_full = None
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  # ---- (2) Move to CPU for sequence planning ----
93
  attn_cpu = attention_mask.to(cpu, dtype=torch.bool)
94
  ids_cpu = input_ids.to(cpu)
95
+ pid_cpu = None
96
  lbls_cpu = labels.to(cpu) if labels is not None else None
97
 
98
  eff_lens: List[int] = []
99
  vid_idx_list: List[torch.Tensor] = []
100
  for b in range(B):
101
  video_grid_thw_b = video_grid_thw[b]
 
102
  grid_area = (int(video_grid_thw_b[1].item()) * int(video_grid_thw_b[2].item())) // 4
103
  grid_area_batch.append(int(grid_area))
104
 
 
113
  else:
114
  vid_idx_list.append(torch.empty(0, dtype=torch.long))
115
 
 
116
  vid_counts = [int(v.numel()) for v in vid_idx_list]
117
  vid_offsets: List[int] = [0] * B
118
  running = 0
 
122
 
123
  # ---- (3) Length planning ----
124
  def _object_block_len(b: int, obj_i: int, sel_latent_len: int, rel_temporal_window_idx: torch.Tensor) -> int:
 
 
 
 
 
 
125
  add = 0
126
 
127
  if obj_traj_start_id is not None:
128
  add += 1
129
 
 
130
  tlen = int(text_token_ids_per_sample[b][obj_i].numel())
131
  add += tlen
132
 
 
133
  if vs_id is not None:
134
  add += 1
135
 
 
136
  if add_timestamp_token and timestamp_token_ids_per_batch is not None:
137
  locs = rel_temporal_window_idx.unique()
138
  for loc in locs:
 
142
  else:
143
  add += int(timestamp_token_ids_per_batch[b][-1].numel())
144
 
 
145
  add += int(sel_latent_len)
146
 
147
  # VE
 
188
  rel = rel.to(cpu, dtype=torch.long)
189
  sel_len = int(rel.numel())
190
 
191
+ tokens_per_window = int(grid_area_batch[b] * int(temporal_window_size_batch[b]))
192
+ rel_temporal_window_idx = rel // tokens_per_window if (tokens_per_window > 0) else torch.zeros_like(rel)
193
+ nonempty_windows = int(rel_temporal_window_idx.unique().numel())
 
194
 
195
+ if use_second_resampler and second_resampler_num_latents is not None:
196
+ sel_len = int(second_resampler_num_latents) + int(resampler_num_latents) * nonempty_windows
 
 
197
  else:
198
+ sel_len = int(resampler_num_latents) * nonempty_windows
 
 
199
 
200
  cur_total += _object_block_len(b, i, sel_len, rel_temporal_window_idx)
201
 
 
213
 
214
  rows_for_video: List[torch.Tensor] = [torch.empty(0, dtype=torch.long) for _ in range(B)]
215
 
216
+ batched_obj_rows: List[torch.Tensor] = []
217
+ batched_obj_pos: List[torch.Tensor] = []
218
  batched_obj_bids: List[int] = []
219
+ batched_obj_lens: List[int] = []
220
 
221
  batched_second_rows: List[torch.Tensor] = []
222
  batched_second_pos: List[torch.Tensor] = []
 
242
 
243
  dst = 0
244
 
 
245
  if vid_idx.numel() == 0:
246
  new_input_ids_cpu[b, :L_eff] = ids_b
247
  new_attention_mask_cpu[b, :L_eff] = msk_b
248
  if new_labels_cpu is not None and labs_b is not None:
249
  new_labels_cpu[b, :L_eff] = labs_b
250
+ new_position_ids_cpu[:, b, :L_eff] = _text_pos_block(0, L_eff, dtype=torch.int32)
 
 
 
251
  continue
252
 
253
  v_s = int(vid_idx[0].item())
 
262
  prefix_len = v_s
263
  suffix_len = L_eff - (v_e + 1)
264
 
 
 
 
 
 
 
 
 
 
 
 
265
  if prefix_len > 0:
266
  new_input_ids_cpu[b, dst:dst + prefix_len] = ids_b[:prefix_len]
267
  new_attention_mask_cpu[b, dst:dst + prefix_len] = msk_b[:prefix_len]
268
  if new_labels_cpu is not None and labs_b is not None:
269
  new_labels_cpu[b, dst:dst + prefix_len] = labs_b[:prefix_len]
270
+ new_position_ids_cpu[:, b, dst:dst + prefix_len] = _text_pos_block(dst, prefix_len, dtype=torch.int32)
 
 
 
271
  dst += prefix_len
272
 
 
 
 
 
 
 
273
  Nv = int(vid_idx.numel())
274
  pos2rank = torch.full((L_eff,), -1, dtype=torch.long, device=cpu)
275
  if Nv > 0:
 
288
  # (1) <obj_traj_start> (optional)
289
  if obj_traj_start_id is not None:
290
  new_input_ids_cpu[b, dst] = int(obj_traj_start_id)
291
+ new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(dst, 1, dtype=torch.int32)
292
  if new_labels_cpu is not None:
293
  new_labels_cpu[b, dst] = IGNORE_ID
294
  new_attention_mask_cpu[b, dst] = True
295
  dst += 1
 
 
296
 
297
  # (2) text tokens (required)
298
  txt_ids = text_token_ids_per_sample[b][i].to(cpu, dtype=torch.long)
299
  k = int(txt_ids.numel())
300
  if k > 0:
301
  new_input_ids_cpu[b, dst:dst + k] = txt_ids
302
+ new_position_ids_cpu[:, b, dst:dst + k] = _text_pos_block(dst, k, dtype=torch.int32)
303
  if new_labels_cpu is not None:
304
  new_labels_cpu[b, dst:dst + k] = IGNORE_ID
305
  new_attention_mask_cpu[b, dst:dst + k] = True
306
  dst += k
 
 
307
 
308
  # (3) <VS> (optional)
309
  if vs_id is not None:
310
  new_input_ids_cpu[b, dst] = int(vs_id)
311
+ new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(dst, 1, dtype=torch.int32)
312
  if new_labels_cpu is not None:
313
  new_labels_cpu[b, dst] = IGNORE_ID
314
  new_attention_mask_cpu[b, dst] = True
315
  dst += 1
 
 
316
 
317
  # (4) video tokens
318
  if g.numel() > 0:
319
+ tokens_per_window = int(grid_area_batch[b] * int(temporal_window_size_batch[b]))
320
+ rel_temporal_window_idx = rel // tokens_per_window if (tokens_per_window > 0) else torch.zeros_like(rel)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
+ W_eff = int(rel_temporal_window_idx.max().item()) + 1 if rel_temporal_window_idx.numel() > 0 else 0
323
+
324
+ all_rows_list = []
325
+ for w in range(W_eff):
326
+ m_w = (rel_temporal_window_idx == w)
327
+ if not torch.any(m_w):
328
+ all_rows_list.append(torch.empty(0, dtype=torch.long, device=cpu))
329
+ continue
330
+ rel_w = rel[m_w]
331
+ rows_w = rel_w + vid_offset
332
+ all_rows_list.append(rows_w)
333
+
334
+ # second resampler: global object summary
335
+ if use_second_resampler and second_resampler is not None:
336
+ rows_all = torch.cat([x for x in all_rows_list if x.numel() > 0], dim=0) if any(x.numel() > 0 for x in all_rows_list) \
337
+ else torch.empty(0, dtype=torch.long, device=cpu)
338
+
339
+ if rows_all.numel() > 0:
340
+ R2 = int(second_resampler_num_latents)
341
+ new_input_ids_cpu[b, dst:dst + R2] = int(vt_id)
342
+ new_position_ids_cpu[:, b, dst:dst + R2] = _text_pos_block( dst, R2, dtype=torch.int32)
343
+ if new_labels_cpu is not None:
344
+ new_labels_cpu[b, dst:dst + R2] = IGNORE_ID
345
+ new_attention_mask_cpu[b, dst:dst + R2] = True
346
+
347
+ pos_idx2 = torch.arange(dst, dst + R2, dtype=torch.long, device=cpu)
348
+ batched_second_rows.append(rows_all)
349
+ batched_second_pos.append(pos_idx2)
350
+ batched_second_bids.append(b)
351
+ batched_second_oids.append(i)
352
+
353
+ dst += R2
354
+
355
+ R = int(resampler_num_latents)
356
+
357
+ for w in range(W_eff):
358
+ m_w = (rel_temporal_window_idx == w)
359
+ if not torch.any(m_w):
360
+ continue
361
+
362
+ # timestamp tokens (text-only; NOT injected into resampler)
363
+ if add_timestamp_token and (timestamp_token_ids_per_batch is not None):
364
+ loc = w
365
+ if loc < len(timestamp_token_ids_per_batch[b]):
366
+ ts_ids = timestamp_token_ids_per_batch[b][loc].to(cpu, dtype=torch.long)
367
+ else:
368
+ ts_ids = timestamp_token_ids_per_batch[b][-1].to(cpu, dtype=torch.long)
369
+ kt = int(ts_ids.numel())
370
+ assert kt > 0, "Timestamp token ids should not be empty."
371
+
372
+ new_input_ids_cpu[b, dst:dst + kt] = ts_ids
373
+ new_position_ids_cpu[:, b, dst:dst + kt] = _text_pos_block(dst, kt, dtype=torch.int32)
374
+ if new_labels_cpu is not None:
375
+ new_labels_cpu[b, dst:dst + kt] = IGNORE_ID
376
+ new_attention_mask_cpu[b, dst:dst + kt] = True
377
+ dst += kt
378
 
 
 
 
379
 
380
+ new_input_ids_cpu[b, dst:dst + R] = int(vt_id)
381
+ new_position_ids_cpu[:, b, dst:dst + R] = _text_pos_block(dst, R, dtype=torch.int32)
 
 
 
 
 
 
 
 
 
 
 
 
382
  if new_labels_cpu is not None:
383
+ new_labels_cpu[b, dst:dst + R] = IGNORE_ID
384
+ new_attention_mask_cpu[b, dst:dst + R] = True
385
+
386
+ rel_w = rel[m_w]
387
+ rows_w = rel_w + vid_offset
388
+ pos_idx = torch.arange(dst, dst + R, dtype=torch.long, device=cpu)
389
 
390
+ batched_obj_rows.append(rows_w)
391
+ batched_obj_pos.append(pos_idx)
392
+ batched_obj_bids.append(b)
393
+ batched_obj_lens.append(int(rows_w.numel()))
394
 
395
+ dst += R
396
  # (5) <VE> (optional)
397
  if ve_id is not None:
398
  new_input_ids_cpu[b, dst] = int(ve_id)
399
+ new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(dst, 1, dtype=torch.int32)
400
  if new_labels_cpu is not None:
401
  new_labels_cpu[b, dst] = IGNORE_ID
402
  new_attention_mask_cpu[b, dst] = True
403
  dst += 1
 
 
404
 
405
  # (6) <obj_traj_end> (optional)
406
  if obj_traj_end_id is not None:
407
  new_input_ids_cpu[b, dst] = int(obj_traj_end_id)
408
+ new_position_ids_cpu[:, b, dst:dst + 1] = _text_pos_block(dst, 1, dtype=torch.int32)
409
  if new_labels_cpu is not None:
410
  new_labels_cpu[b, dst] = IGNORE_ID
411
  new_attention_mask_cpu[b, dst] = True
412
  dst += 1
 
 
413
 
414
  # suffix
415
  if suffix_len > 0:
 
420
  new_attention_mask_cpu[b, dst:dst + seg] = msk_b[src_lo:src_hi]
421
  if new_labels_cpu is not None and labs_b is not None:
422
  new_labels_cpu[b, dst:dst + seg] = labs_b[src_lo:src_hi]
423
+ new_position_ids_cpu[:, b, dst:dst + seg] = _text_pos_block(dst, seg, dtype=torch.int32)
424
  dst += seg
425
 
426
  assert dst == L_new_each[b], f"sample {b}: dst={dst}, L_new={L_new_each[b]}"
 
434
  base = tok_embed(new_input_ids)
435
  new_inputs_embeds = base.clone()
436
 
 
 
 
 
 
 
 
 
 
 
 
437
  # ---- (5.1) second resampler: object-level global summary ----
438
  if use_resampler and use_second_resampler and len(batched_second_rows) > 0:
439
  if video_embeds is None:
 
458
  ar2 = torch.arange(L2_max, device=dev_emb).unsqueeze(0) if L2_max > 0 else torch.zeros(1, 0, device=dev_emb, dtype=torch.long)
459
  mask2 = (ar2 < lens2_t.unsqueeze(1)) if L2_max > 0 else torch.zeros(0, 0, device=dev_emb, dtype=torch.bool)
460
 
461
+ y2 = second_resampler(x2, attention_mask=mask2)
462
  y2 = y2.to(new_inputs_embeds.dtype)
463
 
464
  for j in range(N_obj2):
 
466
  pos2 = batched_second_pos[j].to(dev)
467
  new_inputs_embeds[b_cur, pos2] = y2[j]
468
 
469
+ # ---- (5.2) main resampler: temporal resampler----
470
  if use_resampler and len(batched_obj_rows) > 0:
471
  if video_embeds is None:
472
  raise RuntimeError("use_resampler=True but video_embeds is None.")
 
475
  D = video_embeds.shape[-1]
476
 
477
  N_obj = len(batched_obj_rows)
478
+ lens = torch.tensor(batched_obj_lens, device=dev_emb, dtype=torch.long)
479
  L_max = int(lens.max().item()) if lens.numel() > 0 else 0
480
 
481
  seqs = []
 
483
  if rows.numel() == 0:
484
  seqs.append(torch.zeros(0, D, device=dev_emb, dtype=dtype_emb))
485
  else:
486
+ seqs.append(video_embeds.index_select(0, rows.to(dev_emb)))
487
  x = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True) if len(seqs) > 0 else torch.zeros(0, 0, D, device=dev_emb, dtype=dtype_emb)
488
 
489
  ar = torch.arange(L_max, device=dev_emb).unsqueeze(0) if L_max > 0 else torch.zeros(1, 0, device=dev_emb, dtype=torch.long)
490
  mask = (ar < lens.unsqueeze(1)) if L_max > 0 else torch.zeros(0, 0, device=dev_emb, dtype=torch.bool)
491
 
492
+ y = resampler(x, attention_mask=mask)
493
  y = y.to(new_inputs_embeds.dtype)
494
 
495
  per_b_indices: List[List[int]] = [[] for _ in range(B)]
 
509
  new_inputs_embeds[b, pos_b] = emb_b
510
 
511
  # ---- (6) rope_deltas / cache_position ----
512
+ maxpos = new_position_ids.max(dim=0)[0].max(dim=1, keepdim=True)[0]
513
  rope_deltas = (maxpos + 1 - new_inputs_embeds.shape[1]).to(dtype=torch.long, device=dev)
514
  cache_position = torch.arange(new_inputs_embeds.shape[1], device=dev, dtype=torch.int32)
515