yfan07 commited on
Commit
5f0ac48
Β·
verified Β·
1 Parent(s): c3a125e

Add files using upload-large-folder tool

Browse files
cache_q_smoke/test_s/000000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f85d7cf7b83caf6fedb153a2cea2b36dd144ee3c0e34039483e20d208ea92d3
3
+ size 2327
cache_q_smoke/test_s/index.jsonl ADDED
@@ -0,0 +1 @@
 
 
1
+ {"sample_idx": 0, "path": "000000.pt", "vid": "-3ABOVeVmpU_136000_146000", "refs": ["the object that keeps making sound at all times"], "fids": [1], "resize": [576, 1024], "orgsize": [720, 1280], "num_seg": 1}
data/image_embed.tar CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0b0f5c8ae133bbddbfa558b2052b3aeb757492ffe310650988103d07e24135bb
3
  size 167486740480
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43e0e2002e80457512c6cdb2c171d0323335ea4bbce87ed364da22d267bb931d
3
  size 167486740480
models/avs_model.py CHANGED
@@ -270,6 +270,7 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
270
  epoch: int =0,
271
  inference: bool = False,
272
  num_frames: int = 10,
 
273
  contrast: float = 0.0,
274
 
275
  **kwargs,
@@ -282,14 +283,12 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
282
  # audio_embeddings = torch.cat(audio_features, dim=0) # [B*10, 128]
283
  # audio_embeddings = audio_features # [B, 10, 128]
284
 
285
- # train
286
- if not inference:
287
- target_frame = random.randint(0, 9)
288
  target_frame = 5
289
-
290
  else:
291
- target_frame = 5
292
- # print("target_frame", target_frame)
 
293
 
294
  input_ids, attention_masks, past_key_values, inputs_embeds, labels = super().prepare_inputs_labels_for_multimodal(
295
  input_ids, attention_masks, past_key_values=None, labels=labels, images=images_clip, audio_features=audio_embeddings, target_frame=target_frame, ref_ids=ref_ids
@@ -313,7 +312,8 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
313
  dim=1, ) # [batch_size, seq_len]
314
 
315
 
316
- seg_embeddings = self.model.text_hidden_fcs[0](output_hidden_states[-1][seg_token_mask]) # [seg_num,256]
 
317
 
318
  # print("seg_embeddings in this batch:", seg_embeddings.shape)
319
  # print("vids:", vids)
@@ -337,10 +337,12 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
337
 
338
 
339
  pred_embeddings = []
 
340
  #--------------------------------------------------------------------------------------------
341
  pred_idx = 0
342
  for ref_num in refs_num:
343
  pred_embeddings.append(seg_embeddings[pred_idx:pred_idx + ref_num])
 
344
  pred_idx += ref_num
345
  # list[B]:[num_seg, 256]
346
 
@@ -397,6 +399,7 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
397
  "pred_masks": pred_masks, # list[B]:[num_seg, T, H, W]
398
  "gt_masks": gt_masks, # list[B]:[num_seg, T, H, W]
399
  "seg_embeddings": pred_embeddings, # list[B]:[num_seg, 256]
 
400
  }
401
 
402
  model_output = output
@@ -462,4 +465,3 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
462
 
463
  def evaluate(self, *args, **kwargs):
464
  raise NotImplementedError("This method is not implemented.")
465
-
 
270
  epoch: int =0,
271
  inference: bool = False,
272
  num_frames: int = 10,
273
+ target_frame: int = None,
274
  contrast: float = 0.0,
275
 
276
  **kwargs,
 
283
  # audio_embeddings = torch.cat(audio_features, dim=0) # [B*10, 128]
284
  # audio_embeddings = audio_features # [B, 10, 128]
285
 
286
+ if target_frame is None:
 
 
287
  target_frame = 5
 
288
  else:
289
+ target_frame = int(target_frame)
290
+ if target_frame < 0 or target_frame >= num_frames:
291
+ raise ValueError(f"target_frame must be in [0, {num_frames}), got {target_frame}")
292
 
293
  input_ids, attention_masks, past_key_values, inputs_embeds, labels = super().prepare_inputs_labels_for_multimodal(
294
  input_ids, attention_masks, past_key_values=None, labels=labels, images=images_clip, audio_features=audio_embeddings, target_frame=target_frame, ref_ids=ref_ids
 
312
  dim=1, ) # [batch_size, seq_len]
313
 
314
 
315
+ seg_hidden_states = output_hidden_states[-1][seg_token_mask] # [seg_num, hidden_size]
316
+ seg_embeddings = self.model.text_hidden_fcs[0](seg_hidden_states) # [seg_num,256]
317
 
318
  # print("seg_embeddings in this batch:", seg_embeddings.shape)
319
  # print("vids:", vids)
 
337
 
338
 
339
  pred_embeddings = []
340
+ pred_hidden_states = []
341
  #--------------------------------------------------------------------------------------------
342
  pred_idx = 0
343
  for ref_num in refs_num:
344
  pred_embeddings.append(seg_embeddings[pred_idx:pred_idx + ref_num])
345
+ pred_hidden_states.append(seg_hidden_states[pred_idx:pred_idx + ref_num])
346
  pred_idx += ref_num
347
  # list[B]:[num_seg, 256]
348
 
 
399
  "pred_masks": pred_masks, # list[B]:[num_seg, T, H, W]
400
  "gt_masks": gt_masks, # list[B]:[num_seg, T, H, W]
401
  "seg_embeddings": pred_embeddings, # list[B]:[num_seg, 256]
402
+ "seg_hidden_states": pred_hidden_states, # list[B]:[num_seg, hidden_size]
403
  }
404
 
405
  model_output = output
 
465
 
466
  def evaluate(self, *args, **kwargs):
467
  raise NotImplementedError("This method is not implemented.")
 
models/segment_anything/modeling/mask_decoder.py CHANGED
@@ -140,7 +140,17 @@ class MaskDecoder(nn.Module):
140
  b, c, h, w = src.shape
141
 
142
  # Run the transformer
143
- hs, src = self.transformer(src, pos_src, tokens)
 
 
 
 
 
 
 
 
 
 
144
  iou_token_out = hs[:, 0, :]
145
  mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
146
 
@@ -188,7 +198,17 @@ class MaskDecoder(nn.Module):
188
  _, c, h, w = src.shape
189
 
190
  # Run the transformer
191
- hs, src = self.transformer(src, pos_src, tokens)
 
 
 
 
 
 
 
 
 
 
192
  mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
193
 
194
  # Upscale mask embeddings and predict masks using the mask tokens
 
140
  b, c, h, w = src.shape
141
 
142
  # Run the transformer
143
+ referent_token_index = (
144
+ 1 + self.num_mask_tokens if sparse_prompt_embeddings.shape[1] > 0 else None
145
+ )
146
+ hs, src = self.transformer(
147
+ src,
148
+ pos_src,
149
+ tokens,
150
+ mask_token_start=1,
151
+ num_mask_tokens=self.num_mask_tokens,
152
+ referent_token_index=referent_token_index,
153
+ )
154
  iou_token_out = hs[:, 0, :]
155
  mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
156
 
 
198
  _, c, h, w = src.shape
199
 
200
  # Run the transformer
201
+ referent_token_index = (
202
+ 1 + self.num_mask_tokens if sparse_prompt_embeddings.shape[1] > 0 else None
203
+ )
204
+ hs, src = self.transformer(
205
+ src,
206
+ pos_src,
207
+ tokens,
208
+ mask_token_start=1,
209
+ num_mask_tokens=self.num_mask_tokens,
210
+ referent_token_index=referent_token_index,
211
+ )
212
  mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
213
 
214
  # Upscale mask embeddings and predict masks using the mask tokens
models/segment_anything/modeling/transformer.py CHANGED
@@ -9,6 +9,7 @@ from typing import Tuple, Type
9
 
10
  import torch
11
  from torch import Tensor, nn
 
12
 
13
  from .common import MLPBlock
14
 
@@ -64,6 +65,9 @@ class TwoWayTransformer(nn.Module):
64
  image_embedding: Tensor,
65
  image_pe: Tensor,
66
  point_embedding: Tensor,
 
 
 
67
  ) -> Tuple[Tensor, Tensor]:
68
  """
69
  Args:
@@ -94,6 +98,9 @@ class TwoWayTransformer(nn.Module):
94
  keys=keys,
95
  query_pe=point_embedding,
96
  key_pe=image_pe,
 
 
 
97
  )
98
 
99
  # Apply the final attention layer from the points to the image
@@ -145,11 +152,19 @@ class TwoWayAttentionBlock(nn.Module):
145
  self.cross_attn_image_to_token = Attention(
146
  embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
  )
 
148
 
149
  self.skip_first_layer_pe = skip_first_layer_pe
150
 
151
  def forward(
152
- self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
 
 
 
 
 
 
 
153
  ) -> Tuple[Tensor, Tensor]:
154
  # Self attention block
155
  if self.skip_first_layer_pe:
@@ -160,6 +175,17 @@ class TwoWayAttentionBlock(nn.Module):
160
  queries = queries + attn_out
161
  queries = self.norm1(queries)
162
 
 
 
 
 
 
 
 
 
 
 
 
163
  # Cross attention block, tokens attending to image embedding
164
  q = queries + query_pe
165
  k = keys + key_pe
@@ -182,6 +208,26 @@ class TwoWayAttentionBlock(nn.Module):
182
  return queries, keys
183
 
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  class Attention(nn.Module):
186
  """
187
  An attention layer that allows for downscaling the size of the embedding
 
9
 
10
  import torch
11
  from torch import Tensor, nn
12
+ from torch.nn import functional as F
13
 
14
  from .common import MLPBlock
15
 
 
65
  image_embedding: Tensor,
66
  image_pe: Tensor,
67
  point_embedding: Tensor,
68
+ mask_token_start: int = None,
69
+ num_mask_tokens: int = None,
70
+ referent_token_index: int = None,
71
  ) -> Tuple[Tensor, Tensor]:
72
  """
73
  Args:
 
98
  keys=keys,
99
  query_pe=point_embedding,
100
  key_pe=image_pe,
101
+ mask_token_start=mask_token_start,
102
+ num_mask_tokens=num_mask_tokens,
103
+ referent_token_index=referent_token_index,
104
  )
105
 
106
  # Apply the final attention layer from the points to the image
 
152
  self.cross_attn_image_to_token = Attention(
153
  embedding_dim, num_heads, downsample_rate=attention_downsample_rate
154
  )
155
+ self.referent_gate = ReferentGate(embedding_dim)
156
 
157
  self.skip_first_layer_pe = skip_first_layer_pe
158
 
159
  def forward(
160
+ self,
161
+ queries: Tensor,
162
+ keys: Tensor,
163
+ query_pe: Tensor,
164
+ key_pe: Tensor,
165
+ mask_token_start: int = None,
166
+ num_mask_tokens: int = None,
167
+ referent_token_index: int = None,
168
  ) -> Tuple[Tensor, Tensor]:
169
  # Self attention block
170
  if self.skip_first_layer_pe:
 
175
  queries = queries + attn_out
176
  queries = self.norm1(queries)
177
 
178
+ if (
179
+ mask_token_start is not None
180
+ and num_mask_tokens is not None
181
+ and referent_token_index is not None
182
+ ):
183
+ mask_slice = slice(mask_token_start, mask_token_start + num_mask_tokens)
184
+ mask_tokens = queries[:, mask_slice, :]
185
+ referent_token = queries[:, referent_token_index : referent_token_index + 1, :]
186
+ queries = queries.clone()
187
+ queries[:, mask_slice, :] = self.referent_gate(mask_tokens, referent_token)
188
+
189
  # Cross attention block, tokens attending to image embedding
190
  q = queries + query_pe
191
  k = keys + key_pe
 
208
  return queries, keys
209
 
210
 
211
+ class ReferentGate(nn.Module):
212
+ def __init__(self, embedding_dim: int) -> None:
213
+ super().__init__()
214
+ self.gate = nn.Linear(embedding_dim * 2 + 1, embedding_dim)
215
+ self.proj = nn.Linear(embedding_dim, embedding_dim)
216
+ nn.init.zeros_(self.gate.weight)
217
+ nn.init.zeros_(self.gate.bias)
218
+ nn.init.zeros_(self.proj.weight)
219
+ nn.init.zeros_(self.proj.bias)
220
+ self.last_alpha = None
221
+
222
+ def forward(self, mask_tokens: Tensor, referent_token: Tensor) -> Tensor:
223
+ referent = referent_token.expand_as(mask_tokens)
224
+ cosine = F.cosine_similarity(mask_tokens, referent, dim=-1).unsqueeze(-1)
225
+ gate_input = torch.cat([mask_tokens, referent, cosine], dim=-1)
226
+ alpha = torch.sigmoid(self.gate(gate_input))
227
+ self.last_alpha = alpha.detach()
228
+ return mask_tokens + alpha * self.proj(referent)
229
+
230
+
231
  class Attention(nn.Module):
232
  """
233
  An attention layer that allows for downscaling the size of the embedding
upload_hf.py CHANGED
@@ -27,7 +27,7 @@ IGNORE_PATTERNS = [
27
  "upload.log",
28
  ]
29
 
30
- NUM_WORKERS = 2 # conservative; increase to 8 if no rate-limit errors
31
  MAX_RETRIES = 10
32
  # ───────────────────────────────────────────────────────────────────────────
33
 
 
27
  "upload.log",
28
  ]
29
 
30
+ NUM_WORKERS = 1 # conservative; increase to 8 if no rate-limit errors
31
  MAX_RETRIES = 10
32
  # ───────────────────────────────────────────────────────────────────────────
33