erow commited on
Commit
2fed9dd
·
1 Parent(s): b232299

add vjepa2

Browse files
Files changed (2) hide show
  1. app.py +331 -311
  2. examples/easy_motion_tube2.mp4 +3 -0
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  # Import spaces before torch to avoid CUDA initialization error
3
  try:
4
- import spaces
5
  except ImportError:
6
  pass
7
  import gradio as gr
@@ -11,7 +11,10 @@ import torch
11
  from PIL import Image
12
  from pathlib import Path
13
  from sklearn.decomposition import PCA
14
- from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEModel
 
 
 
15
  from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
  import matplotlib
17
  matplotlib.use('Agg') # Use non-interactive backend
@@ -66,11 +69,9 @@ def load_video(video_path, num_frames=16, sample_rate=4):
66
  cap.release()
67
  return frames
68
 
69
- def load_model(model_name, model_type='pretraining'):
70
- """
71
- Load model and processor by name.
72
- model_type: 'pretraining' for VideoMAEForPreTraining, 'base' for VideoMAEModel
73
- """
74
  processor = VideoMAEImageProcessor.from_pretrained(model_name)
75
  if model_type == 'base':
76
  model = VideoMAEModel.from_pretrained(model_name)
@@ -79,24 +80,8 @@ def load_model(model_name, model_type='pretraining'):
79
  model = model.to(device)
80
  return model, processor
81
 
82
- # Global model and processor
83
- model = None
84
- processor = None
85
-
86
- def initialize_model(model_name='MCG-NJU/videomae-base'):
87
- """Initialize the model (call once at startup)"""
88
- global model, processor
89
- if model is None:
90
- print(f"Loading model: {model_name}")
91
- model, processor = load_model(model_name)
92
- print("Model loaded successfully")
93
- return model, processor
94
-
95
- def visualize_attention(video_frames, model, processor, layer_idx=-1):
96
- """
97
- Visualize attention maps from VideoMAE model.
98
- Returns PIL Image for Gradio.
99
- """
100
  inputs = processor(video_frames, return_tensors="pt")
101
  pixel_values = inputs['pixel_values'].to(device)
102
  batch_size, time, num_channels, height, width = pixel_values.shape
@@ -105,13 +90,11 @@ def visualize_attention(video_frames, model, processor, layer_idx=-1):
105
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
106
  num_temporal_patches = time // tubelet_size
107
 
108
- # Use VideoMAEModel to get attention weights
109
  if hasattr(model, 'videomae'):
110
  encoder_model = model.videomae
111
  else:
112
  encoder_model = model
113
 
114
- # Disable SDPA and use eager attention
115
  original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
116
  encoder_model.config._attn_implementation = "eager"
117
 
@@ -128,7 +111,6 @@ def visualize_attention(video_frames, model, processor, layer_idx=-1):
128
  attention_weights = attentions[layer_idx][0]
129
  avg_attn = attention_weights.mean(dim=0)
130
 
131
- # Unnormalize frames
132
  dtype = pixel_values.dtype
133
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
134
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
@@ -150,14 +132,12 @@ def visualize_attention(video_frames, model, processor, layer_idx=-1):
150
  avg_attn_received = avg_attn.mean(dim=0)
151
  attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
152
 
153
- # Create visualization for first frame
154
  frame_idx = 0
155
  frame_img = frames_unnorm[frame_idx * tubelet_size]
156
  attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
157
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
158
  attn_map_upsampled = cv2.resize(attn_map, (width, height))
159
 
160
- # Create overlay
161
  fig, ax = plt.subplots(1, 1, figsize=(10, 10))
162
  ax.imshow(frame_img)
163
  ax.imshow(attn_map_upsampled, alpha=0.5, cmap='jet')
@@ -166,29 +146,37 @@ def visualize_attention(video_frames, model, processor, layer_idx=-1):
166
 
167
  return fig_to_image(fig)
168
 
169
- def visualize_latent(video_frames, model, processor):
170
- """
171
- Visualize latent space representations from VideoMAE model.
172
- Returns PIL Image for Gradio.
173
- """
174
  inputs = processor(video_frames, return_tensors="pt")
175
  pixel_values = inputs['pixel_values'].to(device)
 
 
 
 
 
176
 
177
  if hasattr(model, 'videomae'):
178
  encoder_model = model.videomae
179
  else:
180
  encoder_model = model
181
 
182
- outputs = encoder_model(pixel_values, output_hidden_states=True)
183
- hidden_states = outputs.last_hidden_state[0]
184
 
185
- batch_size, time, num_channels, height, width = pixel_values.shape
186
- tubelet_size = model.config.tubelet_size
187
- patch_size = model.config.patch_size
188
- num_patches_per_frame = (height // patch_size) * (width // patch_size)
189
- num_temporal_patches = time // tubelet_size
 
 
 
 
 
 
 
190
 
191
- # Unnormalize frames
192
  dtype = pixel_values.dtype
193
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
194
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
@@ -196,71 +184,31 @@ def visualize_latent(video_frames, model, processor):
196
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
197
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
198
 
199
- seq_len = hidden_states.shape[0]
 
 
200
  expected_seq_len = num_temporal_patches * num_patches_per_frame
201
 
202
  if seq_len != expected_seq_len:
203
  if seq_len % num_patches_per_frame == 0:
204
  num_temporal_patches = seq_len // num_patches_per_frame
205
  else:
206
- raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")
207
-
208
- hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
209
- hidden_size = hidden_states_reshaped.shape[-1]
210
- hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()
211
-
212
- pca = PCA(n_components=3)
213
- pca_components = pca.fit_transform(hidden_states_flat)
214
- pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)
215
-
216
- H_p = int(np.sqrt(num_patches_per_frame))
217
- W_p = H_p
218
-
219
- if H_p * W_p == num_patches_per_frame:
220
- pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
221
- else:
222
- factors = []
223
- for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
224
- if num_patches_per_frame % i == 0:
225
- factors.append((i, num_patches_per_frame // i))
226
- if factors:
227
- H_p, W_p = factors[-1]
228
- pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
229
- else:
230
- raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
231
-
232
- # Normalize components
233
- for t in range(num_temporal_patches):
234
- for c in range(3):
235
- comp = pca_spatial[t, :, :, c]
236
- comp_min, comp_max = comp.min(), comp.max()
237
- if comp_max > comp_min:
238
- pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
239
- else:
240
- pca_spatial[t, :, :, c] = 0.5
241
 
242
- # Show first frame
243
- frame_idx = 0
244
- frame_img = frames_unnorm[frame_idx * tubelet_size]
245
- rgb_image = pca_spatial[frame_idx]
246
- upscale_factor = 8
247
- rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
248
 
249
- fig = plt.figure(figsize=(6,6))
250
- ax = fig.add_subplot(1, 1, 1)
251
- ax.imshow(rgb_image_upscaled)
252
- ax.set_title(f"PCA Components (RGB = PC1, PC2, PC3)")
253
- ax.axis('off')
254
- plt.suptitle(f"Explained Variance: {pca.explained_variance_ratio_.sum():.2%}", fontsize=12)
255
- plt.tight_layout()
256
 
257
- return fig_to_image(fig)
258
 
259
- def compute_reconstruction_all_frames(video_frames, model, processor):
260
- """
261
- Compute reconstruction for all frames and return as numpy arrays.
262
- Returns: (original_frames, reconstructed_frames) as numpy arrays
263
- """
264
  inputs = processor(video_frames, return_tensors="pt")
265
  T, C, H, W = inputs['pixel_values'][0].shape
266
  tubelet_size = model.config.tubelet_size
@@ -332,117 +280,113 @@ def compute_reconstruction_all_frames(video_frames, model, processor):
332
 
333
  return original_frames, reconstructed_frames_np
334
 
335
- def visualize_reconstruction(video_frames, model, processor):
336
- """
337
- Visualize reconstruction from VideoMAE model.
338
- Returns PIL Image for Gradio.
339
- """
340
  inputs = processor(video_frames, return_tensors="pt")
341
- T, C, H, W = inputs['pixel_values'][0].shape
342
- tubelet_size = model.config.tubelet_size
343
- patch_size = model.config.patch_size
344
- T = T//tubelet_size
345
 
346
- num_patches = (model.config.image_size // model.config.patch_size) ** 2
347
- num_masked = int(0.9 * num_patches * (model.config.num_frames // model.config.tubelet_size))
348
- total_patches = (model.config.num_frames // model.config.tubelet_size) * num_patches
349
- batch_size = inputs['pixel_values'].shape[0]
350
- bool_masked_pos = torch.zeros((batch_size, total_patches), dtype=torch.bool)
351
-
352
- for b in range(batch_size):
353
- mask_indices = np.random.choice(total_patches, num_masked, replace=False)
354
- bool_masked_pos[b, mask_indices] = True
355
-
356
- inputs['bool_masked_pos'] = bool_masked_pos.to(device)
357
- inputs['pixel_values'] = inputs['pixel_values'].to(device)
358
 
359
- outputs = model(**inputs)
360
- logits = outputs.logits
361
 
362
- pixel_values = inputs['pixel_values']
363
  batch_size, time, num_channels, height, width = pixel_values.shape
364
  tubelet_size = model.config.tubelet_size
365
  patch_size = model.config.patch_size
366
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
367
  num_temporal_patches = time // tubelet_size
368
- total_patches = num_temporal_patches * num_patches_per_frame
369
 
370
  dtype = pixel_values.dtype
371
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
372
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
373
  frames_unnorm = pixel_values * std + mean
 
 
374
 
375
- frames_patched = frames_unnorm.view(
376
- batch_size, time // tubelet_size, tubelet_size, num_channels,
377
- height // patch_size, patch_size, width // patch_size, patch_size,
378
- )
379
- frames_patched = frames_patched.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
380
- videos_patch = frames_patched.view(
381
- batch_size, total_patches, tubelet_size * patch_size * patch_size * num_channels,
382
- )
383
 
384
- if model.config.norm_pix_loss:
385
- patch_mean = videos_patch.mean(dim=-2, keepdim=True)
386
- patch_std = (videos_patch.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
387
- logits_denorm = logits * patch_std + patch_mean
388
- else:
389
- logits_denorm = torch.clamp(logits, 0.0, 1.0)
390
 
391
- reconstructed_patches = videos_patch.clone()
392
- reconstructed_patches[bool_masked_pos] = logits_denorm.reshape(-1, tubelet_size * patch_size * patch_size * num_channels)
 
393
 
394
- reconstructed_patches_reshaped = reconstructed_patches.view(
395
- batch_size, time // tubelet_size, height // patch_size, width // patch_size,
396
- tubelet_size, patch_size, patch_size, num_channels,
397
- )
398
- reconstructed_patches_reshaped = reconstructed_patches_reshaped.permute(0, 1, 4, 7, 2, 5, 3, 6).contiguous()
399
- reconstructed_frames = reconstructed_patches_reshaped.view(
400
- batch_size, time, num_channels, height, width,
401
- )
402
 
403
- original_frames = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
404
- reconstructed_frames_np = reconstructed_frames[0].permute(0, 2, 3, 1).detach().cpu().numpy()
405
 
406
- original_frames = np.clip(original_frames, 0, 1)
407
- reconstructed_frames_np = np.clip(reconstructed_frames_np, 0, 1)
 
 
 
 
 
 
 
 
 
 
408
 
409
- # Show first frame
410
- frame_idx = 0
411
- fig = plt.figure(figsize=(6,6))
412
- ax = plt.subplot(111)
413
-
414
- ax.imshow(reconstructed_frames_np[frame_idx * tubelet_size])
415
- ax.set_title(f"Reconstructed Frame: {frame_idx * tubelet_size}")
416
- ax.axis('off')
417
 
418
- return fig_to_image(fig)
 
 
 
 
 
 
 
419
 
420
- def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
421
- """
422
- Compute attention maps for all frames.
423
- Returns: (original_frames, attention_maps) as numpy arrays
424
- """
425
- inputs = processor(video_frames, return_tensors="pt")
426
- pixel_values = inputs['pixel_values'].to(device)
427
- batch_size, time, num_channels, height, width = pixel_values.shape
 
 
 
 
 
 
 
 
 
428
  tubelet_size = model.config.tubelet_size
429
  patch_size = model.config.patch_size
430
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
431
- num_temporal_patches = time // tubelet_size
432
 
433
- if hasattr(model, 'videomae'):
434
- encoder_model = model.videomae
435
- else:
436
- encoder_model = model
437
-
438
- original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
439
- encoder_model.config._attn_implementation = "eager"
440
 
441
  try:
442
- outputs = encoder_model(pixel_values, output_attentions=True)
443
  finally:
444
  if original_attn_impl is not None:
445
- encoder_model.config._attn_implementation = original_attn_impl
446
 
447
  attentions = outputs.attentions
448
  if layer_idx < 0:
@@ -451,10 +395,10 @@ def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
451
  attention_weights = attentions[layer_idx][0]
452
  avg_attn = attention_weights.mean(dim=0)
453
 
454
- dtype = pixel_values.dtype
455
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
456
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
457
- frames_unnorm = pixel_values * std + mean
458
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
459
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
460
 
@@ -472,7 +416,6 @@ def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
472
  avg_attn_received = avg_attn.mean(dim=0)
473
  attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
474
 
475
- # Create attention maps for all temporal patches
476
  attention_maps = []
477
  for frame_idx in range(num_temporal_patches):
478
  attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
@@ -482,32 +425,75 @@ def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
482
 
483
  return frames_unnorm, attention_maps
484
 
485
- def compute_latent_all_frames(video_frames, model, processor):
486
- """
487
- Compute PCA latent visualizations for all frames.
488
- Returns: (original_frames, pca_images) as numpy arrays
489
- """
490
- inputs = processor(video_frames, return_tensors="pt")
491
- pixel_values = inputs['pixel_values'].to(device)
492
 
493
- if hasattr(model, 'videomae'):
494
- encoder_model = model.videomae
495
- else:
496
- encoder_model = model
497
 
498
- outputs = encoder_model(pixel_values, output_hidden_states=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  hidden_states = outputs.last_hidden_state[0]
500
 
501
- batch_size, time, num_channels, height, width = pixel_values.shape
502
  tubelet_size = model.config.tubelet_size
503
  patch_size = model.config.patch_size
504
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
505
- num_temporal_patches = time // tubelet_size
506
 
507
- dtype = pixel_values.dtype
508
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
509
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
510
- frames_unnorm = pixel_values * std + mean
511
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
512
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
513
 
@@ -544,7 +530,6 @@ def compute_latent_all_frames(video_frames, model, processor):
544
  else:
545
  raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
546
 
547
- # Normalize components
548
  for t in range(num_temporal_patches):
549
  for c in range(3):
550
  comp = pca_spatial[t, :, :, c]
@@ -554,7 +539,6 @@ def compute_latent_all_frames(video_frames, model, processor):
554
  else:
555
  pca_spatial[t, :, :, c] = 0.5
556
 
557
- # Create upscaled images for all frames
558
  upscale_factor = 8
559
  pca_images = []
560
  for t_idx in range(num_temporal_patches):
@@ -564,81 +548,108 @@ def compute_latent_all_frames(video_frames, model, processor):
564
 
565
  return frames_unnorm, pca_images
566
 
567
- # Dummy function for backward compatibility
568
- def process_video(video_path):
569
- cap = cv2.VideoCapture(video_path)
570
- frames = []
571
- while cap.isOpened():
572
- ret, frame = cap.read()
573
- if not ret:
574
- break
575
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
576
- frames.append(frame)
577
- cap.release()
578
- visualizations = [cv2.applyColorMap((f * 0.5).astype(np.uint8), cv2.COLORMAP_JET) for f in frames]
579
- return frames, visualizations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
 
581
  # Global state to store frames after upload
582
  stored_frames = []
583
- stored_viz = []
584
-
585
- # Cache for visualization results: {video_path: {mode: {frame_idx: image}}}
586
  visualization_cache = {}
587
  current_video_path = None
588
 
589
- def on_upload(video_path, mode):
590
- global stored_frames, stored_viz, model, processor, visualization_cache, current_video_path
591
  if video_path is None:
592
  return gr.update(maximum=0), None, None
593
 
594
  # Initialize model if needed
595
- if model is None:
596
- model, processor = initialize_model()
597
 
598
- # Check if we need to recompute (new video or mode not cached)
599
  video_path_str = str(video_path)
 
600
  need_to_load_video = (video_path_str != current_video_path)
601
- need_to_compute_mode = (video_path_str not in visualization_cache or mode not in visualization_cache[video_path_str])
602
 
603
  if need_to_load_video:
604
- # Load video frames
 
 
605
  print(f"Loading video: {video_path_str}")
606
- video_frames = load_video(video_path)
607
  stored_frames = video_frames
608
  current_video_path = video_path_str
609
  else:
610
- # Reuse already loaded frames
611
  video_frames = stored_frames
612
 
613
- # Initialize cache for this video
614
- if video_path_str not in visualization_cache:
615
- visualization_cache[video_path_str] = {}
616
 
617
  if need_to_compute_mode:
618
- # Compute all visualizations and cache them
619
- print(f"Computing {mode} visualization for all frames...")
620
  num_frames = len(stored_frames)
621
  tubelet_size = model.config.tubelet_size
622
 
623
  if mode == "reconstruction":
624
- original_frames, reconstructed_frames = compute_reconstruction_all_frames(video_frames, model, processor)
625
- # Cache as images per frame - map model frames to stored frames
626
- visualization_cache[video_path_str][mode] = {}
627
- for i in range(num_frames):
628
- # Map stored frame index to model frame index
629
- model_frame_idx = min(i, len(reconstructed_frames) - 1)
630
- fig = plt.figure(figsize=(6, 6))
631
- ax = plt.subplot(111)
632
- ax.imshow(reconstructed_frames[model_frame_idx])
633
- ax.set_title(f"Reconstructed Frame: {i}")
634
- ax.axis('off')
635
- visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
 
 
 
 
 
 
 
 
 
 
636
 
637
  elif mode == "attention":
638
- original_frames, attention_maps = compute_attention_all_frames(video_frames, model, processor)
639
- visualization_cache[video_path_str][mode] = {}
 
 
 
 
640
  for i in range(num_frames):
641
- # Map stored frame to temporal patch
642
  temporal_patch_idx = min(i // tubelet_size, len(attention_maps) - 1)
643
  model_frame_idx = min(i, len(original_frames) - 1)
644
  if temporal_patch_idx < len(attention_maps):
@@ -648,13 +659,16 @@ def on_upload(video_path, mode):
648
  ax.imshow(attention_maps[temporal_patch_idx], alpha=0.5, cmap='jet')
649
  ax.set_title(f"Attention Map - Frame {i}")
650
  ax.axis('off')
651
- visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
652
 
653
  elif mode == "latent":
654
- original_frames, pca_images = compute_latent_all_frames(video_frames, model, processor)
655
- visualization_cache[video_path_str][mode] = {}
 
 
 
 
656
  for i in range(num_frames):
657
- # Map stored frame to temporal patch
658
  temporal_patch_idx = min(i // tubelet_size, len(pca_images) - 1)
659
  if temporal_patch_idx < len(pca_images):
660
  fig = plt.figure(figsize=(6, 6))
@@ -662,34 +676,30 @@ def on_upload(video_path, mode):
662
  ax.imshow(pca_images[temporal_patch_idx])
663
  ax.set_title(f"PCA Components - Frame {i}")
664
  ax.axis('off')
665
- visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
666
 
667
- print(f"Caching complete for {mode} mode")
668
 
669
  # Load from cache
670
  max_idx = len(stored_frames) - 1
671
  frame_idx = 0
672
 
673
- # Get original frame
674
  if isinstance(stored_frames[0], Image.Image):
675
  first_frame = np.array(stored_frames[0])
676
  else:
677
  first_frame = stored_frames[0]
678
 
679
- # Get visualization from cache
680
- if video_path_str in visualization_cache and mode in visualization_cache[video_path_str]:
681
- if frame_idx in visualization_cache[video_path_str][mode]:
682
- viz_img = visualization_cache[video_path_str][mode][frame_idx]
683
  else:
684
- # Fallback if frame not in cache
685
  viz_img = Image.fromarray(first_frame)
686
  else:
687
- # Fallback if not cached
688
  viz_img = Image.fromarray(first_frame)
689
 
690
  return gr.update(maximum=max_idx, value=0), first_frame, viz_img
691
 
692
- def update_frame(idx, mode):
693
  global stored_frames, visualization_cache, current_video_path
694
  if not stored_frames:
695
  return None, None
@@ -698,49 +708,57 @@ def update_frame(idx, mode):
698
  if frame_idx >= len(stored_frames):
699
  frame_idx = len(stored_frames) - 1
700
 
701
- # Get frame
702
  if isinstance(stored_frames[frame_idx], Image.Image):
703
  frame = np.array(stored_frames[frame_idx])
704
  else:
705
  frame = stored_frames[frame_idx]
706
 
707
- # Load visualization from cache (fast!)
708
  video_path_str = current_video_path
709
- if video_path_str and video_path_str in visualization_cache:
710
- if mode in visualization_cache[video_path_str]:
711
- if frame_idx in visualization_cache[video_path_str][mode]:
712
- viz_img = visualization_cache[video_path_str][mode][frame_idx]
 
713
  else:
714
- # Fallback if frame not in cache
715
  viz_img = Image.fromarray(frame)
716
  else:
717
- # Mode not cached, return frame
718
  viz_img = Image.fromarray(frame)
719
  else:
720
- # Not cached, return frame
721
  viz_img = Image.fromarray(frame)
722
 
723
  return frame, viz_img
724
 
725
-
726
  def load_example_video(video_file):
727
- def _load_example_video( mode):
728
  """Load the predefined example video"""
729
  example_path = f"examples/{video_file}"
730
- return on_upload(example_path, mode)
731
 
732
  return _load_example_video
733
 
 
 
 
 
 
 
734
  # --- Gradio UI Layout ---
735
- with gr.Blocks(title="VideoMAE Representation Explorer") as demo:
736
- gr.Markdown("## 🎥 VideoMAE Frame-by-Frame Representation Explorer")
737
-
738
- mode_radio = gr.Radio(
739
- choices=["reconstruction", "attention", "latent"],
740
- value="reconstruction",
741
- label="Visualization Mode",
742
- info="Choose the type of visualization"
743
- )
 
 
 
 
 
 
 
744
 
745
  with gr.Row():
746
  with gr.Column():
@@ -752,47 +770,49 @@ with gr.Blocks(title="VideoMAE Representation Explorer") as demo:
752
 
753
  # Event Listeners
754
  video_lists = os.listdir("examples") if os.path.exists("examples") else []
755
- video_lists = os.listdir("app/examples") if os.path.exists("app/examples") else video_lists
756
  with gr.Row():
757
  video_input = gr.Video(label="Upload Video")
758
  with gr.Column():
759
  for video_file in video_lists:
760
  load_example_btn = gr.Button(f"Load Example Video ({video_file})", variant="secondary")
761
- load_example_btn.click(load_example_video(video_file), inputs=mode_radio, outputs=[frame_slider, orig_output, viz_output])
 
 
 
 
762
 
763
- # load_example_btn = gr.Button("Load Example Video (dog.mp4)", variant="secondary")
764
- video_input.change(on_upload, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
 
 
 
 
 
 
 
 
 
765
 
766
- frame_slider.change(update_frame, inputs=[frame_slider, mode_radio], outputs=[orig_output, viz_output])
767
- def on_mode_change(video_path, mode):
768
- """Handle mode change - compute if not cached, otherwise use cache"""
769
- global stored_frames, model, processor, visualization_cache, current_video_path
770
  if video_path is None:
771
  return gr.update(maximum=0), None, None
772
-
773
- video_path_str = str(video_path)
774
-
775
- # If video is already loaded and mode is cached, just return cached result
776
- if video_path_str == current_video_path and video_path_str in visualization_cache:
777
- if mode in visualization_cache[video_path_str]:
778
- max_idx = len(stored_frames) - 1
779
- frame_idx = 0
780
- if isinstance(stored_frames[0], Image.Image):
781
- first_frame = np.array(stored_frames[0])
782
- else:
783
- first_frame = stored_frames[0]
784
- if frame_idx in visualization_cache[video_path_str][mode]:
785
- viz_img = visualization_cache[video_path_str][mode][frame_idx]
786
- else:
787
- viz_img = Image.fromarray(first_frame)
788
- return gr.update(maximum=max_idx, value=0), first_frame, viz_img
789
-
790
- # Otherwise, compute (will use cached video frames if available)
791
- return on_upload(video_path, mode)
792
 
793
- mode_radio.change(on_mode_change, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
 
 
 
 
794
 
795
  if __name__ == "__main__":
796
- # Initialize model at startup
797
- initialize_model()
798
- demo.launch()
 
1
  import os
2
  # Import spaces before torch to avoid CUDA initialization error
3
  try:
4
+ import spaces # pyright: ignore[reportMissingImports]
5
  except ImportError:
6
  pass
7
  import gradio as gr
 
11
  from PIL import Image
12
  from pathlib import Path
13
  from sklearn.decomposition import PCA
14
+ from transformers import (
15
+ VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEModel,
16
+ AutoVideoProcessor, AutoModel
17
+ )
18
  from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
19
  import matplotlib
20
  matplotlib.use('Agg') # Use non-interactive backend
 
69
  cap.release()
70
  return frames
71
 
72
+ # ========== VideoMAE Functions ==========
73
+ def load_videomae_model(model_name='MCG-NJU/videomae-base', model_type='pretraining'):
74
+ """Load VideoMAE model and processor"""
 
 
75
  processor = VideoMAEImageProcessor.from_pretrained(model_name)
76
  if model_type == 'base':
77
  model = VideoMAEModel.from_pretrained(model_name)
 
80
  model = model.to(device)
81
  return model, processor
82
 
83
+ def videomae_visualize_attention(video_frames, model, processor, layer_idx=-1):
84
+ """Visualize attention maps from VideoMAE model"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  inputs = processor(video_frames, return_tensors="pt")
86
  pixel_values = inputs['pixel_values'].to(device)
87
  batch_size, time, num_channels, height, width = pixel_values.shape
 
90
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
91
  num_temporal_patches = time // tubelet_size
92
 
 
93
  if hasattr(model, 'videomae'):
94
  encoder_model = model.videomae
95
  else:
96
  encoder_model = model
97
 
 
98
  original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
99
  encoder_model.config._attn_implementation = "eager"
100
 
 
111
  attention_weights = attentions[layer_idx][0]
112
  avg_attn = attention_weights.mean(dim=0)
113
 
 
114
  dtype = pixel_values.dtype
115
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
116
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
 
132
  avg_attn_received = avg_attn.mean(dim=0)
133
  attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
134
 
 
135
  frame_idx = 0
136
  frame_img = frames_unnorm[frame_idx * tubelet_size]
137
  attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
138
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
139
  attn_map_upsampled = cv2.resize(attn_map, (width, height))
140
 
 
141
  fig, ax = plt.subplots(1, 1, figsize=(10, 10))
142
  ax.imshow(frame_img)
143
  ax.imshow(attn_map_upsampled, alpha=0.5, cmap='jet')
 
146
 
147
  return fig_to_image(fig)
148
 
149
+ def videomae_compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
150
+ """Compute attention maps for all frames - VideoMAE"""
 
 
 
151
  inputs = processor(video_frames, return_tensors="pt")
152
  pixel_values = inputs['pixel_values'].to(device)
153
+ batch_size, time, num_channels, height, width = pixel_values.shape
154
+ tubelet_size = model.config.tubelet_size
155
+ patch_size = model.config.patch_size
156
+ num_patches_per_frame = (height // patch_size) * (width // patch_size)
157
+ num_temporal_patches = time // tubelet_size
158
 
159
  if hasattr(model, 'videomae'):
160
  encoder_model = model.videomae
161
  else:
162
  encoder_model = model
163
 
164
+ original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
165
+ encoder_model.config._attn_implementation = "eager"
166
 
167
+ try:
168
+ outputs = encoder_model(pixel_values, output_attentions=True)
169
+ finally:
170
+ if original_attn_impl is not None:
171
+ encoder_model.config._attn_implementation = original_attn_impl
172
+
173
+ attentions = outputs.attentions
174
+ if layer_idx < 0:
175
+ layer_idx = len(attentions) + layer_idx
176
+
177
+ attention_weights = attentions[layer_idx][0]
178
+ avg_attn = attention_weights.mean(dim=0)
179
 
 
180
  dtype = pixel_values.dtype
181
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
182
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
 
184
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
185
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
186
 
187
+ seq_len = avg_attn.shape[0]
188
+ H_p = height // patch_size
189
+ W_p = width // patch_size
190
  expected_seq_len = num_temporal_patches * num_patches_per_frame
191
 
192
  if seq_len != expected_seq_len:
193
  if seq_len % num_patches_per_frame == 0:
194
  num_temporal_patches = seq_len // num_patches_per_frame
195
  else:
196
+ raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
+ avg_attn_received = avg_attn.mean(dim=0)
199
+ attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
 
 
 
 
200
 
201
+ attention_maps = []
202
+ for frame_idx in range(num_temporal_patches):
203
+ attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
204
+ attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
205
+ attn_map_upsampled = cv2.resize(attn_map, (width, height))
206
+ attention_maps.append(attn_map_upsampled)
 
207
 
208
+ return frames_unnorm, attention_maps
209
 
210
+ def videomae_compute_reconstruction_all_frames(video_frames, model, processor):
211
+ """Compute reconstruction for all frames - VideoMAE"""
 
 
 
212
  inputs = processor(video_frames, return_tensors="pt")
213
  T, C, H, W = inputs['pixel_values'][0].shape
214
  tubelet_size = model.config.tubelet_size
 
280
 
281
  return original_frames, reconstructed_frames_np
282
 
283
def videomae_compute_latent_all_frames(video_frames, model, processor):
    """Compute PCA latent visualizations for all frames - VideoMAE.

    Runs the VideoMAE encoder over the clip, projects every patch token onto
    its first three principal components, and renders them as one RGB map per
    temporal patch (tubelet).

    Args:
        video_frames: list of frames accepted by the VideoMAE image processor.
        model: a ``VideoMAEModel`` or ``VideoMAEForPreTraining`` instance.
        processor: the matching ``VideoMAEImageProcessor``.

    Returns:
        (frames_unnorm, pca_images): un-normalized input frames as a
        ``(T, H, W, C)`` float array in ``[0, 1]``, and one upscaled RGB PCA
        image per temporal patch.

    Raises:
        ValueError: if the token sequence cannot be mapped back onto the
            spatio-temporal patch grid.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    # VideoMAEForPreTraining wraps the encoder in a `.videomae` attribute;
    # a bare VideoMAEModel *is* the encoder.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model

    # Inference only: no_grad avoids building an autograd graph (the
    # original forward kept gradients alive for no reason).
    with torch.no_grad():
        outputs = encoder_model(pixel_values, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state[0]

    # Renamed from `time` to `num_frames` to avoid shadowing the stdlib module name.
    batch_size, num_frames, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = num_frames // tubelet_size

    # Undo ImageNet normalization so the original frames can be displayed.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)

    seq_len = hidden_states.shape[0]
    expected_seq_len = num_temporal_patches * num_patches_per_frame

    if seq_len != expected_seq_len:
        # Tolerate a different temporal extent as long as the spatial grid fits.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")

    # Project every patch token onto 3 PCA components (one per RGB channel).
    hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
    hidden_size = hidden_states_reshaped.shape[-1]
    hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()

    pca = PCA(n_components=3)
    pca_components = pca.fit_transform(hidden_states_flat)
    pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)

    # Recover a 2D patch grid; fall back to the most square factorization
    # when the patch count is not a perfect square.
    H_p = int(np.sqrt(num_patches_per_frame))
    W_p = H_p

    if H_p * W_p == num_patches_per_frame:
        pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
    else:
        factors = []
        for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
            if num_patches_per_frame % i == 0:
                factors.append((i, num_patches_per_frame // i))
        if factors:
            H_p, W_p = factors[-1]
            pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
        else:
            raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")

    # Min-max normalize each component per temporal patch for display.
    for t in range(num_temporal_patches):
        for c in range(3):
            comp = pca_spatial[t, :, :, c]
            comp_min, comp_max = comp.min(), comp.max()
            if comp_max > comp_min:
                pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
            else:
                pca_spatial[t, :, :, c] = 0.5  # constant component -> mid gray

    # Nearest-neighbor upscaling keeps patch boundaries crisp.
    upscale_factor = 8
    pca_images = []
    for t_idx in range(num_temporal_patches):
        rgb_image = pca_spatial[t_idx]
        rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
        pca_images.append(rgb_image_upscaled)

    return frames_unnorm, pca_images
359
 
360
+ # ========== V-JEPA 2 Functions ==========
361
def load_vjepa2_model(model_name='facebook/vjepa2-vitl-fpc64-256'):
    """Load a V-JEPA 2 checkpoint together with its video processor.

    The model is moved to the global ``device`` and put into eval mode
    before being returned as ``(model, processor)``.
    """
    video_processor = AutoVideoProcessor.from_pretrained(model_name)
    vjepa_model = AutoModel.from_pretrained(model_name).to(device)
    vjepa_model.eval()
    return vjepa_model, video_processor
368
+
369
+ def vjepa2_compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
370
+ """Compute attention maps for all frames - V-JEPA 2"""
371
+ video_array = np.stack([np.array(frame) for frame in video_frames])
372
+
373
+ inputs = processor(video_array, return_tensors="pt")
374
+ pixel_values_videos = inputs['pixel_values_videos'].to(device)
375
+ batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
376
+
377
  tubelet_size = model.config.tubelet_size
378
  patch_size = model.config.patch_size
379
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
380
+ num_temporal_patches = num_frames // tubelet_size
381
 
382
+ original_attn_impl = getattr(model.config, '_attn_implementation', None)
383
+ model.config._attn_implementation = "eager"
 
 
 
 
 
384
 
385
  try:
386
+ outputs = model(pixel_values_videos=pixel_values_videos, output_attentions=True, skip_predictor=True)
387
  finally:
388
  if original_attn_impl is not None:
389
+ model.config._attn_implementation = original_attn_impl
390
 
391
  attentions = outputs.attentions
392
  if layer_idx < 0:
 
395
  attention_weights = attentions[layer_idx][0]
396
  avg_attn = attention_weights.mean(dim=0)
397
 
398
+ dtype = pixel_values_videos.dtype
399
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
400
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
401
+ frames_unnorm = pixel_values_videos * std + mean
402
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
403
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
404
 
 
416
  avg_attn_received = avg_attn.mean(dim=0)
417
  attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
418
 
 
419
  attention_maps = []
420
  for frame_idx in range(num_temporal_patches):
421
  attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
 
425
 
426
  return frames_unnorm, attention_maps
427
 
428
def vjepa2_compute_reconstruction_all_frames(video_frames, model, processor):
    """Compute feature visualizations for all frames - V-JEPA 2.

    V-JEPA 2 has no pixel decoder, so instead of a true reconstruction this
    renders the first channel of the encoder's patch embeddings as a
    grayscale map per temporal patch.
    NOTE(review): ``skip_predictor=False`` runs the predictor even though
    only ``last_hidden_state`` is consumed here — confirm whether that is
    intentional before changing it.
    """
    video_array = np.stack([np.array(f) for f in video_frames])

    inputs = processor(video_array, return_tensors="pt")
    pixel_values_videos = inputs['pixel_values_videos'].to(device)
    batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape

    # Inference only — gradients are never needed here.
    with torch.no_grad():
        outputs = model(pixel_values_videos=pixel_values_videos, skip_predictor=False)

    encoder_outputs = outputs.last_hidden_state[0]

    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = num_frames // tubelet_size

    # Undo ImageNet normalization for display.
    dtype = pixel_values_videos.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values_videos * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)

    # Re-derive the temporal extent if the token count disagrees with config.
    seq_len = encoder_outputs.shape[0]
    if seq_len != num_temporal_patches * num_patches_per_frame:
        if seq_len % num_patches_per_frame != 0:
            raise ValueError(f"Cannot reshape encoder outputs: seq_len={seq_len}")
        num_temporal_patches = seq_len // num_patches_per_frame

    # Keep only the first embedding channel of every patch token.
    tokens = encoder_outputs.reshape(num_temporal_patches, num_patches_per_frame, -1)
    first_channel = tokens[:, :, 0].detach().cpu().numpy()

    grid_h = height // patch_size
    grid_w = width // patch_size
    maps_spatial = first_channel.reshape(num_temporal_patches, grid_h, grid_w)

    feature_visualizations = []
    for tube_map in maps_spatial:
        # Min-max normalize, upsample to frame size, replicate to 3 channels.
        normalized = (tube_map - tube_map.min()) / (tube_map.max() - tube_map.min() + 1e-8)
        upsampled = cv2.resize(normalized, (width, height))
        feature_visualizations.append(np.stack([upsampled] * 3, axis=-1))

    return frames_unnorm, feature_visualizations
476
+
477
+ def vjepa2_compute_latent_all_frames(video_frames, model, processor):
478
+ """Compute PCA latent visualizations for all frames - V-JEPA 2"""
479
+ video_array = np.stack([np.array(frame) for frame in video_frames])
480
+
481
+ inputs = processor(video_array, return_tensors="pt")
482
+ pixel_values_videos = inputs['pixel_values_videos'].to(device)
483
+
484
+ outputs = model(pixel_values_videos=pixel_values_videos, output_hidden_states=True, skip_predictor=True)
485
  hidden_states = outputs.last_hidden_state[0]
486
 
487
+ batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
488
  tubelet_size = model.config.tubelet_size
489
  patch_size = model.config.patch_size
490
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
491
+ num_temporal_patches = num_frames // tubelet_size
492
 
493
+ dtype = pixel_values_videos.dtype
494
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
495
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
496
+ frames_unnorm = pixel_values_videos * std + mean
497
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
498
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
499
 
 
530
  else:
531
  raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
532
 
 
533
  for t in range(num_temporal_patches):
534
  for c in range(3):
535
  comp = pca_spatial[t, :, :, c]
 
539
  else:
540
  pca_spatial[t, :, :, c] = 0.5
541
 
 
542
  upscale_factor = 8
543
  pca_images = []
544
  for t_idx in range(num_temporal_patches):
 
548
 
549
  return frames_unnorm, pca_images
550
 
551
# ========== Unified Model Management ==========
# Module-level cache so each backbone is loaded at most once per process.
videomae_model = None
videomae_processor = None
vjepa2_model = None
vjepa2_processor = None
current_model_type = None

def initialize_model(model_type='videomae', model_name=None):
    """Lazily load and cache the selected model.

    Args:
        model_type: either ``'videomae'`` or ``'vjepa2'``.
        model_name: optional checkpoint override; defaults to
            ``'MCG-NJU/videomae-base'`` / ``'facebook/vjepa2-vitl-fpc64-256'``.

    Returns:
        ``(model, processor)`` for the requested backbone.

    Raises:
        ValueError: for an unrecognized ``model_type`` (the original code
            silently returned ``None``, which crashed later at
            ``model.config``).
    """
    global videomae_model, videomae_processor, vjepa2_model, vjepa2_processor, current_model_type

    if model_type == 'videomae':
        if videomae_model is None:
            if model_name is None:
                model_name = 'MCG-NJU/videomae-base'
            print(f"Loading VideoMAE model: {model_name}")
            videomae_model, videomae_processor = load_videomae_model(model_name)
            print("VideoMAE model loaded successfully")
        current_model_type = 'videomae'
        return videomae_model, videomae_processor
    elif model_type == 'vjepa2':
        if vjepa2_model is None:
            if model_name is None:
                model_name = 'facebook/vjepa2-vitl-fpc64-256'
            print(f"Loading V-JEPA 2 model: {model_name}")
            vjepa2_model, vjepa2_processor = load_vjepa2_model(model_name)
            print("V-JEPA 2 model loaded successfully")
        current_model_type = 'vjepa2'
        return vjepa2_model, vjepa2_processor
    raise ValueError(f"Unknown model_type: {model_type!r} (expected 'videomae' or 'vjepa2')")
581
 
582
  # Global state to store frames after upload
583
  stored_frames = []
 
 
 
584
  visualization_cache = {}
585
  current_video_path = None
586
 
587
+ def on_upload(video_path, mode, model_type):
588
+ global stored_frames, visualization_cache, current_video_path
589
  if video_path is None:
590
  return gr.update(maximum=0), None, None
591
 
592
  # Initialize model if needed
593
+ model, processor = initialize_model(model_type)
 
594
 
595
+ # Check if we need to recompute (new video, mode, or model type)
596
  video_path_str = str(video_path)
597
+ cache_key = f"{video_path_str}_{model_type}"
598
  need_to_load_video = (video_path_str != current_video_path)
599
+ need_to_compute_mode = (cache_key not in visualization_cache or mode not in visualization_cache[cache_key])
600
 
601
  if need_to_load_video:
602
+ # Adjust frame count based on model type
603
+ num_frames = 64 if model_type == 'vjepa2' else 16
604
+ sample_rate = 1 if model_type == 'vjepa2' else 4
605
  print(f"Loading video: {video_path_str}")
606
+ video_frames = load_video(video_path, num_frames=num_frames, sample_rate=sample_rate)
607
  stored_frames = video_frames
608
  current_video_path = video_path_str
609
  else:
 
610
  video_frames = stored_frames
611
 
612
+ # Initialize cache for this video+model combination
613
+ if cache_key not in visualization_cache:
614
+ visualization_cache[cache_key] = {}
615
 
616
  if need_to_compute_mode:
617
+ print(f"Computing {mode} visualization for all frames using {model_type}...")
 
618
  num_frames = len(stored_frames)
619
  tubelet_size = model.config.tubelet_size
620
 
621
  if mode == "reconstruction":
622
+ if model_type == 'videomae':
623
+ original_frames, reconstructed_frames = videomae_compute_reconstruction_all_frames(video_frames, model, processor)
624
+ visualization_cache[cache_key][mode] = {}
625
+ for i in range(num_frames):
626
+ model_frame_idx = min(i, len(reconstructed_frames) - 1)
627
+ fig = plt.figure(figsize=(6, 6))
628
+ ax = plt.subplot(111)
629
+ ax.imshow(reconstructed_frames[model_frame_idx])
630
+ ax.set_title(f"Reconstructed Frame: {i}")
631
+ ax.axis('off')
632
+ visualization_cache[cache_key][mode][i] = fig_to_image(fig)
633
+ else: # vjepa2
634
+ original_frames, feature_visualizations = vjepa2_compute_reconstruction_all_frames(video_frames, model, processor)
635
+ visualization_cache[cache_key][mode] = {}
636
+ for i in range(num_frames):
637
+ temporal_patch_idx = min(i // tubelet_size, len(feature_visualizations) - 1)
638
+ fig = plt.figure(figsize=(6, 6))
639
+ ax = plt.subplot(111)
640
+ ax.imshow(feature_visualizations[temporal_patch_idx])
641
+ ax.set_title(f"Encoder Features - Frame {i}")
642
+ ax.axis('off')
643
+ visualization_cache[cache_key][mode][i] = fig_to_image(fig)
644
 
645
  elif mode == "attention":
646
+ if model_type == 'videomae':
647
+ original_frames, attention_maps = videomae_compute_attention_all_frames(video_frames, model, processor)
648
+ else: # vjepa2
649
+ original_frames, attention_maps = vjepa2_compute_attention_all_frames(video_frames, model, processor)
650
+
651
+ visualization_cache[cache_key][mode] = {}
652
  for i in range(num_frames):
 
653
  temporal_patch_idx = min(i // tubelet_size, len(attention_maps) - 1)
654
  model_frame_idx = min(i, len(original_frames) - 1)
655
  if temporal_patch_idx < len(attention_maps):
 
659
  ax.imshow(attention_maps[temporal_patch_idx], alpha=0.5, cmap='jet')
660
  ax.set_title(f"Attention Map - Frame {i}")
661
  ax.axis('off')
662
+ visualization_cache[cache_key][mode][i] = fig_to_image(fig)
663
 
664
  elif mode == "latent":
665
+ if model_type == 'videomae':
666
+ original_frames, pca_images = videomae_compute_latent_all_frames(video_frames, model, processor)
667
+ else: # vjepa2
668
+ original_frames, pca_images = vjepa2_compute_latent_all_frames(video_frames, model, processor)
669
+
670
+ visualization_cache[cache_key][mode] = {}
671
  for i in range(num_frames):
 
672
  temporal_patch_idx = min(i // tubelet_size, len(pca_images) - 1)
673
  if temporal_patch_idx < len(pca_images):
674
  fig = plt.figure(figsize=(6, 6))
 
676
  ax.imshow(pca_images[temporal_patch_idx])
677
  ax.set_title(f"PCA Components - Frame {i}")
678
  ax.axis('off')
679
+ visualization_cache[cache_key][mode][i] = fig_to_image(fig)
680
 
681
+ print(f"Caching complete for {mode} mode using {model_type}")
682
 
683
  # Load from cache
684
  max_idx = len(stored_frames) - 1
685
  frame_idx = 0
686
 
 
687
  if isinstance(stored_frames[0], Image.Image):
688
  first_frame = np.array(stored_frames[0])
689
  else:
690
  first_frame = stored_frames[0]
691
 
692
+ if cache_key in visualization_cache and mode in visualization_cache[cache_key]:
693
+ if frame_idx in visualization_cache[cache_key][mode]:
694
+ viz_img = visualization_cache[cache_key][mode][frame_idx]
 
695
  else:
 
696
  viz_img = Image.fromarray(first_frame)
697
  else:
 
698
  viz_img = Image.fromarray(first_frame)
699
 
700
  return gr.update(maximum=max_idx, value=0), first_frame, viz_img
701
 
702
+ def update_frame(idx, mode, model_type):
703
  global stored_frames, visualization_cache, current_video_path
704
  if not stored_frames:
705
  return None, None
 
708
  if frame_idx >= len(stored_frames):
709
  frame_idx = len(stored_frames) - 1
710
 
 
711
  if isinstance(stored_frames[frame_idx], Image.Image):
712
  frame = np.array(stored_frames[frame_idx])
713
  else:
714
  frame = stored_frames[frame_idx]
715
 
 
716
  video_path_str = current_video_path
717
+ cache_key = f"{video_path_str}_{model_type}" if video_path_str else None
718
+ if cache_key and cache_key in visualization_cache:
719
+ if mode in visualization_cache[cache_key]:
720
+ if frame_idx in visualization_cache[cache_key][mode]:
721
+ viz_img = visualization_cache[cache_key][mode][frame_idx]
722
  else:
 
723
  viz_img = Image.fromarray(frame)
724
  else:
 
725
  viz_img = Image.fromarray(frame)
726
  else:
 
727
  viz_img = Image.fromarray(frame)
728
 
729
  return frame, viz_img
730
 
 
731
def load_example_video(video_file):
    """Build a click handler that loads a bundled example video.

    Returns a closure suitable for a Gradio button callback; it forwards the
    example's path plus the current mode/model selection to ``on_upload``.
    """
    def _handler(mode, model_type):
        """Load the predefined example video"""
        return on_upload(f"examples/{video_file}", mode, model_type)

    return _handler
738
 
739
def on_model_change(model_type, video_path, mode):
    """Recompute the visualization when the model selection changes."""
    if video_path is None:
        # Nothing loaded yet: reset the slider and clear both image panes.
        return gr.update(maximum=0), None, None
    # Re-run the full upload pipeline with the newly selected backbone.
    return on_upload(video_path, mode, model_type)
744
+
745
  # --- Gradio UI Layout ---
746
+ with gr.Blocks(title="Video Representation Explorer") as demo:
747
+ gr.Markdown("## 🎥 Video Representation Explorer - VideoMAE & V-JEPA 2")
748
+
749
+ with gr.Row():
750
+ model_radio = gr.Radio(
751
+ choices=["videomae", "vjepa2"],
752
+ value="videomae",
753
+ label="Model Type",
754
+ info="Choose between VideoMAE and V-JEPA 2"
755
+ )
756
+ mode_radio = gr.Radio(
757
+ choices=["reconstruction", "attention", "latent"],
758
+ value="reconstruction",
759
+ label="Visualization Mode",
760
+ info="Choose the type of visualization"
761
+ )
762
 
763
  with gr.Row():
764
  with gr.Column():
 
770
 
771
  # Event Listeners
772
  video_lists = os.listdir("examples") if os.path.exists("examples") else []
773
+ video_lists = os.listdir("app/examples") if os.path.exists("app/examples") else video_lists
774
  with gr.Row():
775
  video_input = gr.Video(label="Upload Video")
776
  with gr.Column():
777
  for video_file in video_lists:
778
  load_example_btn = gr.Button(f"Load Example Video ({video_file})", variant="secondary")
779
+ load_example_btn.click(
780
+ load_example_video(video_file),
781
+ inputs=[mode_radio, model_radio],
782
+ outputs=[frame_slider, orig_output, viz_output]
783
+ )
784
 
785
+ video_input.change(
786
+ on_upload,
787
+ inputs=[video_input, mode_radio, model_radio],
788
+ outputs=[frame_slider, orig_output, viz_output]
789
+ )
790
+
791
+ frame_slider.change(
792
+ update_frame,
793
+ inputs=[frame_slider, mode_radio, model_radio],
794
+ outputs=[orig_output, viz_output]
795
+ )
796
 
797
+ def on_mode_change(video_path, mode, model_type):
798
+ """Handle mode change"""
 
 
799
  if video_path is None:
800
  return gr.update(maximum=0), None, None
801
+ return on_upload(video_path, mode, model_type)
802
+
803
+ mode_radio.change(
804
+ on_mode_change,
805
+ inputs=[video_input, mode_radio, model_radio],
806
+ outputs=[frame_slider, orig_output, viz_output]
807
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
808
 
809
+ model_radio.change(
810
+ on_model_change,
811
+ inputs=[model_radio, video_input, mode_radio],
812
+ outputs=[frame_slider, orig_output, viz_output]
813
+ )
814
 
815
  if __name__ == "__main__":
816
+ # Initialize default model at startup
817
+ initialize_model('videomae')
818
+ demo.launch(share=True)
examples/easy_motion_tube2.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f38d7badf4d90059fa50c905daec70a107e9bfd225d7fba4e89416f20927656
3
+ size 49261