erow committed on
Commit
9adc4ed
·
1 Parent(s): f68420d
Files changed (1) hide show
  1. app.py +311 -331
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  # Import spaces before torch to avoid CUDA initialization error
3
  try:
4
- import spaces # pyright: ignore[reportMissingImports]
5
  except ImportError:
6
  pass
7
  import gradio as gr
@@ -11,10 +11,7 @@ import torch
11
  from PIL import Image
12
  from pathlib import Path
13
  from sklearn.decomposition import PCA
14
- from transformers import (
15
- VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEModel,
16
- AutoVideoProcessor, AutoModel
17
- )
18
  from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
19
  import matplotlib
20
  matplotlib.use('Agg') # Use non-interactive backend
@@ -69,9 +66,11 @@ def load_video(video_path, num_frames=16, sample_rate=4):
69
  cap.release()
70
  return frames
71
 
72
- # ========== VideoMAE Functions ==========
73
- def load_videomae_model(model_name='MCG-NJU/videomae-base', model_type='pretraining'):
74
- """Load VideoMAE model and processor"""
 
 
75
  processor = VideoMAEImageProcessor.from_pretrained(model_name)
76
  if model_type == 'base':
77
  model = VideoMAEModel.from_pretrained(model_name)
@@ -80,8 +79,24 @@ def load_videomae_model(model_name='MCG-NJU/videomae-base', model_type='pretrain
80
  model = model.to(device)
81
  return model, processor
82
 
83
- def videomae_visualize_attention(video_frames, model, processor, layer_idx=-1):
84
- """Visualize attention maps from VideoMAE model"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  inputs = processor(video_frames, return_tensors="pt")
86
  pixel_values = inputs['pixel_values'].to(device)
87
  batch_size, time, num_channels, height, width = pixel_values.shape
@@ -90,11 +105,13 @@ def videomae_visualize_attention(video_frames, model, processor, layer_idx=-1):
90
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
91
  num_temporal_patches = time // tubelet_size
92
 
 
93
  if hasattr(model, 'videomae'):
94
  encoder_model = model.videomae
95
  else:
96
  encoder_model = model
97
 
 
98
  original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
99
  encoder_model.config._attn_implementation = "eager"
100
 
@@ -111,6 +128,7 @@ def videomae_visualize_attention(video_frames, model, processor, layer_idx=-1):
111
  attention_weights = attentions[layer_idx][0]
112
  avg_attn = attention_weights.mean(dim=0)
113
 
 
114
  dtype = pixel_values.dtype
115
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
116
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
@@ -132,12 +150,14 @@ def videomae_visualize_attention(video_frames, model, processor, layer_idx=-1):
132
  avg_attn_received = avg_attn.mean(dim=0)
133
  attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
134
 
 
135
  frame_idx = 0
136
  frame_img = frames_unnorm[frame_idx * tubelet_size]
137
  attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
138
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
139
  attn_map_upsampled = cv2.resize(attn_map, (width, height))
140
 
 
141
  fig, ax = plt.subplots(1, 1, figsize=(10, 10))
142
  ax.imshow(frame_img)
143
  ax.imshow(attn_map_upsampled, alpha=0.5, cmap='jet')
@@ -146,37 +166,29 @@ def videomae_visualize_attention(video_frames, model, processor, layer_idx=-1):
146
 
147
  return fig_to_image(fig)
148
 
149
- def videomae_compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
150
- """Compute attention maps for all frames - VideoMAE"""
 
 
 
151
  inputs = processor(video_frames, return_tensors="pt")
152
  pixel_values = inputs['pixel_values'].to(device)
153
- batch_size, time, num_channels, height, width = pixel_values.shape
154
- tubelet_size = model.config.tubelet_size
155
- patch_size = model.config.patch_size
156
- num_patches_per_frame = (height // patch_size) * (width // patch_size)
157
- num_temporal_patches = time // tubelet_size
158
 
159
  if hasattr(model, 'videomae'):
160
  encoder_model = model.videomae
161
  else:
162
  encoder_model = model
163
 
164
- original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
165
- encoder_model.config._attn_implementation = "eager"
166
-
167
- try:
168
- outputs = encoder_model(pixel_values, output_attentions=True)
169
- finally:
170
- if original_attn_impl is not None:
171
- encoder_model.config._attn_implementation = original_attn_impl
172
-
173
- attentions = outputs.attentions
174
- if layer_idx < 0:
175
- layer_idx = len(attentions) + layer_idx
176
 
177
- attention_weights = attentions[layer_idx][0]
178
- avg_attn = attention_weights.mean(dim=0)
 
 
 
179
 
 
180
  dtype = pixel_values.dtype
181
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
182
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
@@ -184,31 +196,71 @@ def videomae_compute_attention_all_frames(video_frames, model, processor, layer_
184
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
185
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
186
 
187
- seq_len = avg_attn.shape[0]
188
- H_p = height // patch_size
189
- W_p = width // patch_size
190
  expected_seq_len = num_temporal_patches * num_patches_per_frame
191
 
192
  if seq_len != expected_seq_len:
193
  if seq_len % num_patches_per_frame == 0:
194
  num_temporal_patches = seq_len // num_patches_per_frame
195
  else:
196
- raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")
197
 
198
- avg_attn_received = avg_attn.mean(dim=0)
199
- attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
 
200
 
201
- attention_maps = []
202
- for frame_idx in range(num_temporal_patches):
203
- attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
204
- attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
205
- attn_map_upsampled = cv2.resize(attn_map, (width, height))
206
- attention_maps.append(attn_map_upsampled)
207
 
208
- return frames_unnorm, attention_maps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- def videomae_compute_reconstruction_all_frames(video_frames, model, processor):
211
- """Compute reconstruction for all frames - VideoMAE"""
 
 
 
212
  inputs = processor(video_frames, return_tensors="pt")
213
  T, C, H, W = inputs['pixel_values'][0].shape
214
  tubelet_size = model.config.tubelet_size
@@ -280,113 +332,117 @@ def videomae_compute_reconstruction_all_frames(video_frames, model, processor):
280
 
281
  return original_frames, reconstructed_frames_np
282
 
283
- def videomae_compute_latent_all_frames(video_frames, model, processor):
284
- """Compute PCA latent visualizations for all frames - VideoMAE"""
 
 
 
285
  inputs = processor(video_frames, return_tensors="pt")
286
- pixel_values = inputs['pixel_values'].to(device)
 
 
 
287
 
288
- if hasattr(model, 'videomae'):
289
- encoder_model = model.videomae
290
- else:
291
- encoder_model = model
 
 
 
 
 
 
 
 
292
 
293
- outputs = encoder_model(pixel_values, output_hidden_states=True)
294
- hidden_states = outputs.last_hidden_state[0]
295
 
 
296
  batch_size, time, num_channels, height, width = pixel_values.shape
297
  tubelet_size = model.config.tubelet_size
298
  patch_size = model.config.patch_size
299
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
300
  num_temporal_patches = time // tubelet_size
 
301
 
302
  dtype = pixel_values.dtype
303
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
304
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
305
  frames_unnorm = pixel_values * std + mean
306
- frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
307
- frames_unnorm = np.clip(frames_unnorm, 0, 1)
308
-
309
- seq_len = hidden_states.shape[0]
310
- expected_seq_len = num_temporal_patches * num_patches_per_frame
311
-
312
- if seq_len != expected_seq_len:
313
- if seq_len % num_patches_per_frame == 0:
314
- num_temporal_patches = seq_len // num_patches_per_frame
315
- else:
316
- raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")
317
 
318
- hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
319
- hidden_size = hidden_states_reshaped.shape[-1]
320
- hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()
 
 
 
 
 
321
 
322
- pca = PCA(n_components=3)
323
- pca_components = pca.fit_transform(hidden_states_flat)
324
- pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)
 
 
 
325
 
326
- H_p = int(np.sqrt(num_patches_per_frame))
327
- W_p = H_p
328
 
329
- if H_p * W_p == num_patches_per_frame:
330
- pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
331
- else:
332
- factors = []
333
- for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
334
- if num_patches_per_frame % i == 0:
335
- factors.append((i, num_patches_per_frame // i))
336
- if factors:
337
- H_p, W_p = factors[-1]
338
- pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
339
- else:
340
- raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
341
 
342
- for t in range(num_temporal_patches):
343
- for c in range(3):
344
- comp = pca_spatial[t, :, :, c]
345
- comp_min, comp_max = comp.min(), comp.max()
346
- if comp_max > comp_min:
347
- pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
348
- else:
349
- pca_spatial[t, :, :, c] = 0.5
350
 
351
- upscale_factor = 8
352
- pca_images = []
353
- for t_idx in range(num_temporal_patches):
354
- rgb_image = pca_spatial[t_idx]
355
- rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
356
- pca_images.append(rgb_image_upscaled)
357
 
358
- return frames_unnorm, pca_images
359
-
360
- # ========== V-JEPA 2 Functions ==========
361
- def load_vjepa2_model(model_name='facebook/vjepa2-vitl-fpc64-256'):
362
- """Load V-JEPA 2 model and processor"""
363
- processor = AutoVideoProcessor.from_pretrained(model_name)
364
- model = AutoModel.from_pretrained(model_name)
365
- model = model.to(device)
366
- model.eval()
367
- return model, processor
368
 
369
- def vjepa2_compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
370
- """Compute attention maps for all frames - V-JEPA 2"""
371
- video_array = np.stack([np.array(frame) for frame in video_frames])
372
-
373
- inputs = processor(video_array, return_tensors="pt")
374
- pixel_values_videos = inputs['pixel_values_videos'].to(device)
375
- batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
376
 
 
 
 
 
 
 
 
 
 
 
377
  tubelet_size = model.config.tubelet_size
378
  patch_size = model.config.patch_size
379
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
380
- num_temporal_patches = num_frames // tubelet_size
381
 
382
- original_attn_impl = getattr(model.config, '_attn_implementation', None)
383
- model.config._attn_implementation = "eager"
 
 
 
 
 
384
 
385
  try:
386
- outputs = model(pixel_values_videos=pixel_values_videos, output_attentions=True, skip_predictor=True)
387
  finally:
388
  if original_attn_impl is not None:
389
- model.config._attn_implementation = original_attn_impl
390
 
391
  attentions = outputs.attentions
392
  if layer_idx < 0:
@@ -395,10 +451,10 @@ def vjepa2_compute_attention_all_frames(video_frames, model, processor, layer_id
395
  attention_weights = attentions[layer_idx][0]
396
  avg_attn = attention_weights.mean(dim=0)
397
 
398
- dtype = pixel_values_videos.dtype
399
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
400
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
401
- frames_unnorm = pixel_values_videos * std + mean
402
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
403
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
404
 
@@ -416,6 +472,7 @@ def vjepa2_compute_attention_all_frames(video_frames, model, processor, layer_id
416
  avg_attn_received = avg_attn.mean(dim=0)
417
  attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
418
 
 
419
  attention_maps = []
420
  for frame_idx in range(num_temporal_patches):
421
  attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
@@ -425,75 +482,32 @@ def vjepa2_compute_attention_all_frames(video_frames, model, processor, layer_id
425
 
426
  return frames_unnorm, attention_maps
427
 
428
- def vjepa2_compute_reconstruction_all_frames(video_frames, model, processor):
429
- """Compute feature visualizations for all frames - V-JEPA 2"""
430
- video_array = np.stack([np.array(frame) for frame in video_frames])
431
-
432
- inputs = processor(video_array, return_tensors="pt")
433
- pixel_values_videos = inputs['pixel_values_videos'].to(device)
434
- batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
435
-
436
- with torch.no_grad():
437
- outputs = model(pixel_values_videos=pixel_values_videos, skip_predictor=False)
438
-
439
- encoder_outputs = outputs.last_hidden_state[0]
440
-
441
- tubelet_size = model.config.tubelet_size
442
- patch_size = model.config.patch_size
443
- num_patches_per_frame = (height // patch_size) * (width // patch_size)
444
- num_temporal_patches = num_frames // tubelet_size
445
-
446
- dtype = pixel_values_videos.dtype
447
- mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
448
- std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
449
- frames_unnorm = pixel_values_videos * std + mean
450
- frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
451
- frames_unnorm = np.clip(frames_unnorm, 0, 1)
452
-
453
- seq_len = encoder_outputs.shape[0]
454
- if seq_len != num_temporal_patches * num_patches_per_frame:
455
- if seq_len % num_patches_per_frame == 0:
456
- num_temporal_patches = seq_len // num_patches_per_frame
457
- else:
458
- raise ValueError(f"Cannot reshape encoder outputs: seq_len={seq_len}")
459
-
460
- encoder_reshaped = encoder_outputs.reshape(num_temporal_patches, num_patches_per_frame, -1)
461
- feature_maps = encoder_reshaped[:, :, 0].detach().cpu().numpy()
462
-
463
- H_p = height // patch_size
464
- W_p = width // patch_size
465
- feature_maps_spatial = feature_maps.reshape(num_temporal_patches, H_p, W_p)
466
-
467
- feature_visualizations = []
468
- for t_idx in range(num_temporal_patches):
469
- feat_map = feature_maps_spatial[t_idx]
470
- feat_map = (feat_map - feat_map.min()) / (feat_map.max() - feat_map.min() + 1e-8)
471
- feat_map_upsampled = cv2.resize(feat_map, (width, height))
472
- feat_rgb = np.stack([feat_map_upsampled] * 3, axis=-1)
473
- feature_visualizations.append(feat_rgb)
474
-
475
- return frames_unnorm, feature_visualizations
476
-
477
- def vjepa2_compute_latent_all_frames(video_frames, model, processor):
478
- """Compute PCA latent visualizations for all frames - V-JEPA 2"""
479
- video_array = np.stack([np.array(frame) for frame in video_frames])
480
 
481
- inputs = processor(video_array, return_tensors="pt")
482
- pixel_values_videos = inputs['pixel_values_videos'].to(device)
 
 
483
 
484
- outputs = model(pixel_values_videos=pixel_values_videos, output_hidden_states=True, skip_predictor=True)
485
  hidden_states = outputs.last_hidden_state[0]
486
 
487
- batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
488
  tubelet_size = model.config.tubelet_size
489
  patch_size = model.config.patch_size
490
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
491
- num_temporal_patches = num_frames // tubelet_size
492
 
493
- dtype = pixel_values_videos.dtype
494
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
495
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
496
- frames_unnorm = pixel_values_videos * std + mean
497
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
498
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
499
 
@@ -530,6 +544,7 @@ def vjepa2_compute_latent_all_frames(video_frames, model, processor):
530
  else:
531
  raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
532
 
 
533
  for t in range(num_temporal_patches):
534
  for c in range(3):
535
  comp = pca_spatial[t, :, :, c]
@@ -539,6 +554,7 @@ def vjepa2_compute_latent_all_frames(video_frames, model, processor):
539
  else:
540
  pca_spatial[t, :, :, c] = 0.5
541
 
 
542
  upscale_factor = 8
543
  pca_images = []
544
  for t_idx in range(num_temporal_patches):
@@ -548,108 +564,81 @@ def vjepa2_compute_latent_all_frames(video_frames, model, processor):
548
 
549
  return frames_unnorm, pca_images
550
 
551
- # ========== Unified Model Management ==========
552
- # Global model and processor
553
- videomae_model = None
554
- videomae_processor = None
555
- vjepa2_model = None
556
- vjepa2_processor = None
557
- current_model_type = None
558
-
559
- def initialize_model(model_type='videomae', model_name=None):
560
- """Initialize the selected model"""
561
- global videomae_model, videomae_processor, vjepa2_model, vjepa2_processor, current_model_type
562
-
563
- if model_type == 'videomae':
564
- if videomae_model is None:
565
- if model_name is None:
566
- model_name = 'MCG-NJU/videomae-base'
567
- print(f"Loading VideoMAE model: {model_name}")
568
- videomae_model, videomae_processor = load_videomae_model(model_name)
569
- print("VideoMAE model loaded successfully")
570
- current_model_type = 'videomae'
571
- return videomae_model, videomae_processor
572
- elif model_type == 'vjepa2':
573
- if vjepa2_model is None:
574
- if model_name is None:
575
- model_name = 'facebook/vjepa2-vitl-fpc64-256'
576
- print(f"Loading V-JEPA 2 model: {model_name}")
577
- vjepa2_model, vjepa2_processor = load_vjepa2_model(model_name)
578
- print("V-JEPA 2 model loaded successfully")
579
- current_model_type = 'vjepa2'
580
- return vjepa2_model, vjepa2_processor
581
 
582
  # Global state to store frames after upload
583
  stored_frames = []
 
 
 
584
  visualization_cache = {}
585
  current_video_path = None
586
 
587
- def on_upload(video_path, mode, model_type):
588
- global stored_frames, visualization_cache, current_video_path
589
  if video_path is None:
590
  return gr.update(maximum=0), None, None
591
 
592
  # Initialize model if needed
593
- model, processor = initialize_model(model_type)
 
594
 
595
- # Check if we need to recompute (new video, mode, or model type)
596
  video_path_str = str(video_path)
597
- cache_key = f"{video_path_str}_{model_type}"
598
  need_to_load_video = (video_path_str != current_video_path)
599
- need_to_compute_mode = (cache_key not in visualization_cache or mode not in visualization_cache[cache_key])
600
 
601
  if need_to_load_video:
602
- # Adjust frame count based on model type
603
- num_frames = 64 if model_type == 'vjepa2' else 16
604
- sample_rate = 1 if model_type == 'vjepa2' else 4
605
  print(f"Loading video: {video_path_str}")
606
- video_frames = load_video(video_path, num_frames=num_frames, sample_rate=sample_rate)
607
  stored_frames = video_frames
608
  current_video_path = video_path_str
609
  else:
 
610
  video_frames = stored_frames
611
 
612
- # Initialize cache for this video+model combination
613
- if cache_key not in visualization_cache:
614
- visualization_cache[cache_key] = {}
615
 
616
  if need_to_compute_mode:
617
- print(f"Computing {mode} visualization for all frames using {model_type}...")
 
618
  num_frames = len(stored_frames)
619
  tubelet_size = model.config.tubelet_size
620
 
621
  if mode == "reconstruction":
622
- if model_type == 'videomae':
623
- original_frames, reconstructed_frames = videomae_compute_reconstruction_all_frames(video_frames, model, processor)
624
- visualization_cache[cache_key][mode] = {}
625
- for i in range(num_frames):
626
- model_frame_idx = min(i, len(reconstructed_frames) - 1)
627
- fig = plt.figure(figsize=(6, 6))
628
- ax = plt.subplot(111)
629
- ax.imshow(reconstructed_frames[model_frame_idx])
630
- ax.set_title(f"Reconstructed Frame: {i}")
631
- ax.axis('off')
632
- visualization_cache[cache_key][mode][i] = fig_to_image(fig)
633
- else: # vjepa2
634
- original_frames, feature_visualizations = vjepa2_compute_reconstruction_all_frames(video_frames, model, processor)
635
- visualization_cache[cache_key][mode] = {}
636
- for i in range(num_frames):
637
- temporal_patch_idx = min(i // tubelet_size, len(feature_visualizations) - 1)
638
- fig = plt.figure(figsize=(6, 6))
639
- ax = plt.subplot(111)
640
- ax.imshow(feature_visualizations[temporal_patch_idx])
641
- ax.set_title(f"Encoder Features - Frame {i}")
642
- ax.axis('off')
643
- visualization_cache[cache_key][mode][i] = fig_to_image(fig)
644
 
645
  elif mode == "attention":
646
- if model_type == 'videomae':
647
- original_frames, attention_maps = videomae_compute_attention_all_frames(video_frames, model, processor)
648
- else: # vjepa2
649
- original_frames, attention_maps = vjepa2_compute_attention_all_frames(video_frames, model, processor)
650
-
651
- visualization_cache[cache_key][mode] = {}
652
  for i in range(num_frames):
 
653
  temporal_patch_idx = min(i // tubelet_size, len(attention_maps) - 1)
654
  model_frame_idx = min(i, len(original_frames) - 1)
655
  if temporal_patch_idx < len(attention_maps):
@@ -659,16 +648,13 @@ def on_upload(video_path, mode, model_type):
659
  ax.imshow(attention_maps[temporal_patch_idx], alpha=0.5, cmap='jet')
660
  ax.set_title(f"Attention Map - Frame {i}")
661
  ax.axis('off')
662
- visualization_cache[cache_key][mode][i] = fig_to_image(fig)
663
 
664
  elif mode == "latent":
665
- if model_type == 'videomae':
666
- original_frames, pca_images = videomae_compute_latent_all_frames(video_frames, model, processor)
667
- else: # vjepa2
668
- original_frames, pca_images = vjepa2_compute_latent_all_frames(video_frames, model, processor)
669
-
670
- visualization_cache[cache_key][mode] = {}
671
  for i in range(num_frames):
 
672
  temporal_patch_idx = min(i // tubelet_size, len(pca_images) - 1)
673
  if temporal_patch_idx < len(pca_images):
674
  fig = plt.figure(figsize=(6, 6))
@@ -676,30 +662,34 @@ def on_upload(video_path, mode, model_type):
676
  ax.imshow(pca_images[temporal_patch_idx])
677
  ax.set_title(f"PCA Components - Frame {i}")
678
  ax.axis('off')
679
- visualization_cache[cache_key][mode][i] = fig_to_image(fig)
680
 
681
- print(f"Caching complete for {mode} mode using {model_type}")
682
 
683
  # Load from cache
684
  max_idx = len(stored_frames) - 1
685
  frame_idx = 0
686
 
 
687
  if isinstance(stored_frames[0], Image.Image):
688
  first_frame = np.array(stored_frames[0])
689
  else:
690
  first_frame = stored_frames[0]
691
 
692
- if cache_key in visualization_cache and mode in visualization_cache[cache_key]:
693
- if frame_idx in visualization_cache[cache_key][mode]:
694
- viz_img = visualization_cache[cache_key][mode][frame_idx]
 
695
  else:
 
696
  viz_img = Image.fromarray(first_frame)
697
  else:
 
698
  viz_img = Image.fromarray(first_frame)
699
 
700
  return gr.update(maximum=max_idx, value=0), first_frame, viz_img
701
 
702
- def update_frame(idx, mode, model_type):
703
  global stored_frames, visualization_cache, current_video_path
704
  if not stored_frames:
705
  return None, None
@@ -708,57 +698,49 @@ def update_frame(idx, mode, model_type):
708
  if frame_idx >= len(stored_frames):
709
  frame_idx = len(stored_frames) - 1
710
 
 
711
  if isinstance(stored_frames[frame_idx], Image.Image):
712
  frame = np.array(stored_frames[frame_idx])
713
  else:
714
  frame = stored_frames[frame_idx]
715
 
 
716
  video_path_str = current_video_path
717
- cache_key = f"{video_path_str}_{model_type}" if video_path_str else None
718
- if cache_key and cache_key in visualization_cache:
719
- if mode in visualization_cache[cache_key]:
720
- if frame_idx in visualization_cache[cache_key][mode]:
721
- viz_img = visualization_cache[cache_key][mode][frame_idx]
722
  else:
 
723
  viz_img = Image.fromarray(frame)
724
  else:
 
725
  viz_img = Image.fromarray(frame)
726
  else:
 
727
  viz_img = Image.fromarray(frame)
728
 
729
  return frame, viz_img
730
 
 
731
  def load_example_video(video_file):
732
- def _load_example_video(mode, model_type):
733
  """Load the predefined example video"""
734
  example_path = f"examples/{video_file}"
735
- return on_upload(example_path, mode, model_type)
736
 
737
  return _load_example_video
738
 
739
- def on_model_change(model_type, video_path, mode):
740
- """Handle model type change"""
741
- if video_path is None:
742
- return gr.update(maximum=0), None, None
743
- return on_upload(video_path, mode, model_type)
744
-
745
  # --- Gradio UI Layout ---
746
- with gr.Blocks(title="Video Representation Explorer") as demo:
747
- gr.Markdown("## 🎥 Video Representation Explorer - VideoMAE & V-JEPA 2")
748
-
749
- with gr.Row():
750
- model_radio = gr.Radio(
751
- choices=["videomae", "vjepa2"],
752
- value="videomae",
753
- label="Model Type",
754
- info="Choose between VideoMAE and V-JEPA 2"
755
- )
756
- mode_radio = gr.Radio(
757
- choices=["reconstruction", "attention", "latent"],
758
- value="reconstruction",
759
- label="Visualization Mode",
760
- info="Choose the type of visualization"
761
- )
762
 
763
  with gr.Row():
764
  with gr.Column():
@@ -770,49 +752,47 @@ with gr.Blocks(title="Video Representation Explorer") as demo:
770
 
771
  # Event Listeners
772
  video_lists = os.listdir("examples") if os.path.exists("examples") else []
773
- video_lists = os.listdir("app/examples") if os.path.exists("app/examples") else video_lists
774
  with gr.Row():
775
  video_input = gr.Video(label="Upload Video")
776
  with gr.Column():
777
  for video_file in video_lists:
778
  load_example_btn = gr.Button(f"Load Example Video ({video_file})", variant="secondary")
779
- load_example_btn.click(
780
- load_example_video(video_file),
781
- inputs=[mode_radio, model_radio],
782
- outputs=[frame_slider, orig_output, viz_output]
783
- )
784
 
785
- video_input.change(
786
- on_upload,
787
- inputs=[video_input, mode_radio, model_radio],
788
- outputs=[frame_slider, orig_output, viz_output]
789
- )
790
-
791
- frame_slider.change(
792
- update_frame,
793
- inputs=[frame_slider, mode_radio, model_radio],
794
- outputs=[orig_output, viz_output]
795
- )
796
 
797
- def on_mode_change(video_path, mode, model_type):
798
- """Handle mode change"""
 
 
799
  if video_path is None:
800
  return gr.update(maximum=0), None, None
801
- return on_upload(video_path, mode, model_type)
802
-
803
- mode_radio.change(
804
- on_mode_change,
805
- inputs=[video_input, mode_radio, model_radio],
806
- outputs=[frame_slider, orig_output, viz_output]
807
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
808
 
809
- model_radio.change(
810
- on_model_change,
811
- inputs=[model_radio, video_input, mode_radio],
812
- outputs=[frame_slider, orig_output, viz_output]
813
- )
814
 
815
  if __name__ == "__main__":
816
- # Initialize default model at startup
817
- initialize_model('videomae')
818
- demo.launch()
 
1
  import os
2
  # Import spaces before torch to avoid CUDA initialization error
3
  try:
4
+ import spaces
5
  except ImportError:
6
  pass
7
  import gradio as gr
 
11
  from PIL import Image
12
  from pathlib import Path
13
  from sklearn.decomposition import PCA
14
+ from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEModel
 
 
 
15
  from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
  import matplotlib
17
  matplotlib.use('Agg') # Use non-interactive backend
 
66
  cap.release()
67
  return frames
68
 
69
+ def load_model(model_name, model_type='pretraining'):
70
+ """
71
+ Load model and processor by name.
72
+ model_type: 'pretraining' for VideoMAEForPreTraining, 'base' for VideoMAEModel
73
+ """
74
  processor = VideoMAEImageProcessor.from_pretrained(model_name)
75
  if model_type == 'base':
76
  model = VideoMAEModel.from_pretrained(model_name)
 
79
  model = model.to(device)
80
  return model, processor
81
 
82
+ # Global model and processor
83
+ model = None
84
+ processor = None
85
+
86
def initialize_model(model_name='MCG-NJU/videomae-base'):
    """
    Return the shared (model, processor) pair, loading it on first use.

    The pair is kept in the module-level ``model`` / ``processor`` globals.
    Once a model has been loaded, later calls return the existing pair
    unchanged and ``model_name`` is ignored.
    """
    global model, processor

    # Already initialised: hand back the cached pair without reloading.
    if model is not None:
        return model, processor

    print(f"Loading model: {model_name}")
    model, processor = load_model(model_name)
    print("Model loaded successfully")
    return model, processor
94
+
95
+ def visualize_attention(video_frames, model, processor, layer_idx=-1):
96
+ """
97
+ Visualize attention maps from VideoMAE model.
98
+ Returns PIL Image for Gradio.
99
+ """
100
  inputs = processor(video_frames, return_tensors="pt")
101
  pixel_values = inputs['pixel_values'].to(device)
102
  batch_size, time, num_channels, height, width = pixel_values.shape
 
105
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
106
  num_temporal_patches = time // tubelet_size
107
 
108
+ # Use VideoMAEModel to get attention weights
109
  if hasattr(model, 'videomae'):
110
  encoder_model = model.videomae
111
  else:
112
  encoder_model = model
113
 
114
+ # Disable SDPA and use eager attention
115
  original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
116
  encoder_model.config._attn_implementation = "eager"
117
 
 
128
  attention_weights = attentions[layer_idx][0]
129
  avg_attn = attention_weights.mean(dim=0)
130
 
131
+ # Unnormalize frames
132
  dtype = pixel_values.dtype
133
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
134
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
 
150
  avg_attn_received = avg_attn.mean(dim=0)
151
  attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
152
 
153
+ # Create visualization for first frame
154
  frame_idx = 0
155
  frame_img = frames_unnorm[frame_idx * tubelet_size]
156
  attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
157
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
158
  attn_map_upsampled = cv2.resize(attn_map, (width, height))
159
 
160
+ # Create overlay
161
  fig, ax = plt.subplots(1, 1, figsize=(10, 10))
162
  ax.imshow(frame_img)
163
  ax.imshow(attn_map_upsampled, alpha=0.5, cmap='jet')
 
166
 
167
  return fig_to_image(fig)
168
 
169
+ def visualize_latent(video_frames, model, processor):
170
+ """
171
+ Visualize latent space representations from VideoMAE model.
172
+ Returns PIL Image for Gradio.
173
+ """
174
  inputs = processor(video_frames, return_tensors="pt")
175
  pixel_values = inputs['pixel_values'].to(device)
 
 
 
 
 
176
 
177
  if hasattr(model, 'videomae'):
178
  encoder_model = model.videomae
179
  else:
180
  encoder_model = model
181
 
182
+ outputs = encoder_model(pixel_values, output_hidden_states=True)
183
+ hidden_states = outputs.last_hidden_state[0]
 
 
 
 
 
 
 
 
 
 
184
 
185
+ batch_size, time, num_channels, height, width = pixel_values.shape
186
+ tubelet_size = model.config.tubelet_size
187
+ patch_size = model.config.patch_size
188
+ num_patches_per_frame = (height // patch_size) * (width // patch_size)
189
+ num_temporal_patches = time // tubelet_size
190
 
191
+ # Unnormalize frames
192
  dtype = pixel_values.dtype
193
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
194
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
 
196
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
197
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
198
 
199
+ seq_len = hidden_states.shape[0]
 
 
200
  expected_seq_len = num_temporal_patches * num_patches_per_frame
201
 
202
  if seq_len != expected_seq_len:
203
  if seq_len % num_patches_per_frame == 0:
204
  num_temporal_patches = seq_len // num_patches_per_frame
205
  else:
206
+ raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")
207
 
208
+ hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
209
+ hidden_size = hidden_states_reshaped.shape[-1]
210
+ hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()
211
 
212
+ pca = PCA(n_components=3)
213
+ pca_components = pca.fit_transform(hidden_states_flat)
214
+ pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)
 
 
 
215
 
216
+ H_p = int(np.sqrt(num_patches_per_frame))
217
+ W_p = H_p
218
+
219
+ if H_p * W_p == num_patches_per_frame:
220
+ pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
221
+ else:
222
+ factors = []
223
+ for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
224
+ if num_patches_per_frame % i == 0:
225
+ factors.append((i, num_patches_per_frame // i))
226
+ if factors:
227
+ H_p, W_p = factors[-1]
228
+ pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
229
+ else:
230
+ raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
231
+
232
+ # Normalize components
233
+ for t in range(num_temporal_patches):
234
+ for c in range(3):
235
+ comp = pca_spatial[t, :, :, c]
236
+ comp_min, comp_max = comp.min(), comp.max()
237
+ if comp_max > comp_min:
238
+ pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
239
+ else:
240
+ pca_spatial[t, :, :, c] = 0.5
241
+
242
+ # Show first frame
243
+ frame_idx = 0
244
+ frame_img = frames_unnorm[frame_idx * tubelet_size]
245
+ rgb_image = pca_spatial[frame_idx]
246
+ upscale_factor = 8
247
+ rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
248
+
249
+ fig = plt.figure(figsize=(6,6))
250
+ ax = fig.add_subplot(1, 1, 1)
251
+ ax.imshow(rgb_image_upscaled)
252
+ ax.set_title(f"PCA Components (RGB = PC1, PC2, PC3)")
253
+ ax.axis('off')
254
+ plt.suptitle(f"Explained Variance: {pca.explained_variance_ratio_.sum():.2%}", fontsize=12)
255
+ plt.tight_layout()
256
+
257
+ return fig_to_image(fig)
258
 
259
+ def compute_reconstruction_all_frames(video_frames, model, processor):
260
+ """
261
+ Compute reconstruction for all frames and return as numpy arrays.
262
+ Returns: (original_frames, reconstructed_frames) as numpy arrays
263
+ """
264
  inputs = processor(video_frames, return_tensors="pt")
265
  T, C, H, W = inputs['pixel_values'][0].shape
266
  tubelet_size = model.config.tubelet_size
 
332
 
333
  return original_frames, reconstructed_frames_np
334
 
335
def visualize_reconstruction(video_frames, model, processor):
    """
    Visualize reconstruction from VideoMAE model.

    Randomly masks 90% of the space-time patches, runs the model on the
    masked input, pastes the predicted pixels back into the masked positions
    and renders the first frame of the result.

    Args:
        video_frames: Frames in whatever form ``processor`` accepts.
        model: Pretraining model; must expose ``config.tubelet_size``,
            ``config.patch_size`` and ``config.norm_pix_loss`` and return an
            output with ``logits`` when called with ``pixel_values`` and
            ``bool_masked_pos``.
        processor: Callable returning a mapping that contains ``pixel_values``.

    Returns:
        PIL Image for Gradio.
    """
    inputs = processor(video_frames, return_tensors="pt")
    inputs['pixel_values'] = inputs['pixel_values'].to(device)
    pixel_values = inputs['pixel_values']
    batch_size, time, num_channels, height, width = pixel_values.shape

    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_temporal_patches = time // tubelet_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    total_patches = num_temporal_patches * num_patches_per_frame
    pixels_per_patch = tubelet_size * patch_size * patch_size

    # Build a random mask hiding 90% of the patches. Its size is derived from
    # the actual input tensor so it always lines up with the patch grid below.
    num_masked = int(0.9 * num_patches_per_frame * num_temporal_patches)
    bool_masked_pos = torch.zeros((batch_size, total_patches), dtype=torch.bool)
    for b in range(batch_size):
        mask_indices = np.random.choice(total_patches, num_masked, replace=False)
        bool_masked_pos[b, mask_indices] = True
    bool_masked_pos = bool_masked_pos.to(device)
    inputs['bool_masked_pos'] = bool_masked_pos

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Undo the ImageNet normalisation applied by the processor.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean

    # Split frames into tubelet patches: (batch, patches, pixels_per_patch, channels).
    frames_patched = frames_unnorm.view(
        batch_size, num_temporal_patches, tubelet_size, num_channels,
        height // patch_size, patch_size, width // patch_size, patch_size,
    )
    frames_patched = frames_patched.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
    patches = frames_patched.view(
        batch_size, total_patches, pixels_per_patch, num_channels,
    )

    predictions = logits.reshape(-1, pixels_per_patch, num_channels)
    if model.config.norm_pix_loss:
        # With norm_pix_loss the targets are normalised per patch and per
        # channel (statistics over the pixels of a patch), so each masked
        # prediction must be de-normalised with its own patch's mean/std,
        # not with statistics pooled across patches.
        patch_mean = patches.mean(dim=-2, keepdim=True)
        patch_std = patches.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
        predictions = predictions * patch_std[bool_masked_pos] + patch_mean[bool_masked_pos]
    else:
        predictions = torch.clamp(predictions, 0.0, 1.0)

    # Keep visible patches as-is and overwrite only the masked ones.
    reconstructed_patches = patches.clone()
    reconstructed_patches[bool_masked_pos] = predictions

    # Inverse of the patch split above: back to (batch, time, channels, H, W).
    reconstructed_patches_reshaped = reconstructed_patches.view(
        batch_size, num_temporal_patches, height // patch_size, width // patch_size,
        tubelet_size, patch_size, patch_size, num_channels,
    )
    reconstructed_patches_reshaped = reconstructed_patches_reshaped.permute(0, 1, 4, 7, 2, 5, 3, 6).contiguous()
    reconstructed_frames = reconstructed_patches_reshaped.view(
        batch_size, time, num_channels, height, width,
    )

    reconstructed_frames_np = reconstructed_frames[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    reconstructed_frames_np = np.clip(reconstructed_frames_np, 0, 1)

    # Show first frame
    frame_idx = 0
    fig = plt.figure(figsize=(6,6))
    ax = fig.add_subplot(111)
    ax.imshow(reconstructed_frames_np[frame_idx * tubelet_size])
    ax.set_title(f"Reconstructed Frame: {frame_idx * tubelet_size}")
    ax.axis('off')

    return fig_to_image(fig)
419
+
420
+ def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
421
+ """
422
+ Compute attention maps for all frames.
423
+ Returns: (original_frames, attention_maps) as numpy arrays
424
+ """
425
+ inputs = processor(video_frames, return_tensors="pt")
426
+ pixel_values = inputs['pixel_values'].to(device)
427
+ batch_size, time, num_channels, height, width = pixel_values.shape
428
  tubelet_size = model.config.tubelet_size
429
  patch_size = model.config.patch_size
430
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
431
+ num_temporal_patches = time // tubelet_size
432
 
433
+ if hasattr(model, 'videomae'):
434
+ encoder_model = model.videomae
435
+ else:
436
+ encoder_model = model
437
+
438
+ original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
439
+ encoder_model.config._attn_implementation = "eager"
440
 
441
  try:
442
+ outputs = encoder_model(pixel_values, output_attentions=True)
443
  finally:
444
  if original_attn_impl is not None:
445
+ encoder_model.config._attn_implementation = original_attn_impl
446
 
447
  attentions = outputs.attentions
448
  if layer_idx < 0:
 
451
  attention_weights = attentions[layer_idx][0]
452
  avg_attn = attention_weights.mean(dim=0)
453
 
454
+ dtype = pixel_values.dtype
455
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
456
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
457
+ frames_unnorm = pixel_values * std + mean
458
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
459
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
460
 
 
472
  avg_attn_received = avg_attn.mean(dim=0)
473
  attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
474
 
475
+ # Create attention maps for all temporal patches
476
  attention_maps = []
477
  for frame_idx in range(num_temporal_patches):
478
  attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
 
482
 
483
  return frames_unnorm, attention_maps
484
 
485
+ def compute_latent_all_frames(video_frames, model, processor):
486
+ """
487
+ Compute PCA latent visualizations for all frames.
488
+ Returns: (original_frames, pca_images) as numpy arrays
489
+ """
490
+ inputs = processor(video_frames, return_tensors="pt")
491
+ pixel_values = inputs['pixel_values'].to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
 
493
+ if hasattr(model, 'videomae'):
494
+ encoder_model = model.videomae
495
+ else:
496
+ encoder_model = model
497
 
498
+ outputs = encoder_model(pixel_values, output_hidden_states=True)
499
  hidden_states = outputs.last_hidden_state[0]
500
 
501
+ batch_size, time, num_channels, height, width = pixel_values.shape
502
  tubelet_size = model.config.tubelet_size
503
  patch_size = model.config.patch_size
504
  num_patches_per_frame = (height // patch_size) * (width // patch_size)
505
+ num_temporal_patches = time // tubelet_size
506
 
507
+ dtype = pixel_values.dtype
508
  mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
509
  std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
510
+ frames_unnorm = pixel_values * std + mean
511
  frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
512
  frames_unnorm = np.clip(frames_unnorm, 0, 1)
513
 
 
544
  else:
545
  raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
546
 
547
+ # Normalize components
548
  for t in range(num_temporal_patches):
549
  for c in range(3):
550
  comp = pca_spatial[t, :, :, c]
 
554
  else:
555
  pca_spatial[t, :, :, c] = 0.5
556
 
557
+ # Create upscaled images for all frames
558
  upscale_factor = 8
559
  pca_images = []
560
  for t_idx in range(num_temporal_patches):
 
564
 
565
  return frames_unnorm, pca_images
566
 
567
+ # Dummy function for backward compatibility
568
def process_video(video_path):
    """
    Read every frame of a video and build a placeholder colormap view of each.

    Kept as a dummy for backward compatibility.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        (frames, visualizations): ``frames`` holds each decoded frame after
        BGR->RGB conversion; ``visualizations`` holds the JET colormap applied
        to each frame scaled by 0.5.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle even if decoding or conversion raises.
        cap.release()
    visualizations = [cv2.applyColorMap((f * 0.5).astype(np.uint8), cv2.COLORMAP_JET) for f in frames]
    return frames, visualizations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
 
581
  # Global state to store frames after upload
582
  stored_frames = []
583
+ stored_viz = []
584
+
585
+ # Cache for visualization results: {video_path: {mode: {frame_idx: image}}}
586
  visualization_cache = {}
587
  current_video_path = None
588
 
589
+ def on_upload(video_path, mode):
590
+ global stored_frames, stored_viz, model, processor, visualization_cache, current_video_path
591
  if video_path is None:
592
  return gr.update(maximum=0), None, None
593
 
594
  # Initialize model if needed
595
+ if model is None:
596
+ model, processor = initialize_model()
597
 
598
+ # Check if we need to recompute (new video or mode not cached)
599
  video_path_str = str(video_path)
 
600
  need_to_load_video = (video_path_str != current_video_path)
601
+ need_to_compute_mode = (video_path_str not in visualization_cache or mode not in visualization_cache[video_path_str])
602
 
603
  if need_to_load_video:
604
+ # Load video frames
 
 
605
  print(f"Loading video: {video_path_str}")
606
+ video_frames = load_video(video_path)
607
  stored_frames = video_frames
608
  current_video_path = video_path_str
609
  else:
610
+ # Reuse already loaded frames
611
  video_frames = stored_frames
612
 
613
+ # Initialize cache for this video
614
+ if video_path_str not in visualization_cache:
615
+ visualization_cache[video_path_str] = {}
616
 
617
  if need_to_compute_mode:
618
+ # Compute all visualizations and cache them
619
+ print(f"Computing {mode} visualization for all frames...")
620
  num_frames = len(stored_frames)
621
  tubelet_size = model.config.tubelet_size
622
 
623
  if mode == "reconstruction":
624
+ original_frames, reconstructed_frames = compute_reconstruction_all_frames(video_frames, model, processor)
625
+ # Cache as images per frame - map model frames to stored frames
626
+ visualization_cache[video_path_str][mode] = {}
627
+ for i in range(num_frames):
628
+ # Map stored frame index to model frame index
629
+ model_frame_idx = min(i, len(reconstructed_frames) - 1)
630
+ fig = plt.figure(figsize=(6, 6))
631
+ ax = plt.subplot(111)
632
+ ax.imshow(reconstructed_frames[model_frame_idx])
633
+ ax.set_title(f"Reconstructed Frame: {i}")
634
+ ax.axis('off')
635
+ visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
 
 
 
 
 
 
 
 
 
 
636
 
637
  elif mode == "attention":
638
+ original_frames, attention_maps = compute_attention_all_frames(video_frames, model, processor)
639
+ visualization_cache[video_path_str][mode] = {}
 
 
 
 
640
  for i in range(num_frames):
641
+ # Map stored frame to temporal patch
642
  temporal_patch_idx = min(i // tubelet_size, len(attention_maps) - 1)
643
  model_frame_idx = min(i, len(original_frames) - 1)
644
  if temporal_patch_idx < len(attention_maps):
 
648
  ax.imshow(attention_maps[temporal_patch_idx], alpha=0.5, cmap='jet')
649
  ax.set_title(f"Attention Map - Frame {i}")
650
  ax.axis('off')
651
+ visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
652
 
653
  elif mode == "latent":
654
+ original_frames, pca_images = compute_latent_all_frames(video_frames, model, processor)
655
+ visualization_cache[video_path_str][mode] = {}
 
 
 
 
656
  for i in range(num_frames):
657
+ # Map stored frame to temporal patch
658
  temporal_patch_idx = min(i // tubelet_size, len(pca_images) - 1)
659
  if temporal_patch_idx < len(pca_images):
660
  fig = plt.figure(figsize=(6, 6))
 
662
  ax.imshow(pca_images[temporal_patch_idx])
663
  ax.set_title(f"PCA Components - Frame {i}")
664
  ax.axis('off')
665
+ visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
666
 
667
+ print(f"Caching complete for {mode} mode")
668
 
669
  # Load from cache
670
  max_idx = len(stored_frames) - 1
671
  frame_idx = 0
672
 
673
+ # Get original frame
674
  if isinstance(stored_frames[0], Image.Image):
675
  first_frame = np.array(stored_frames[0])
676
  else:
677
  first_frame = stored_frames[0]
678
 
679
+ # Get visualization from cache
680
+ if video_path_str in visualization_cache and mode in visualization_cache[video_path_str]:
681
+ if frame_idx in visualization_cache[video_path_str][mode]:
682
+ viz_img = visualization_cache[video_path_str][mode][frame_idx]
683
  else:
684
+ # Fallback if frame not in cache
685
  viz_img = Image.fromarray(first_frame)
686
  else:
687
+ # Fallback if not cached
688
  viz_img = Image.fromarray(first_frame)
689
 
690
  return gr.update(maximum=max_idx, value=0), first_frame, viz_img
691
 
692
+ def update_frame(idx, mode):
693
  global stored_frames, visualization_cache, current_video_path
694
  if not stored_frames:
695
  return None, None
 
698
  if frame_idx >= len(stored_frames):
699
  frame_idx = len(stored_frames) - 1
700
 
701
+ # Get frame
702
  if isinstance(stored_frames[frame_idx], Image.Image):
703
  frame = np.array(stored_frames[frame_idx])
704
  else:
705
  frame = stored_frames[frame_idx]
706
 
707
+ # Load visualization from cache (fast!)
708
  video_path_str = current_video_path
709
+ if video_path_str and video_path_str in visualization_cache:
710
+ if mode in visualization_cache[video_path_str]:
711
+ if frame_idx in visualization_cache[video_path_str][mode]:
712
+ viz_img = visualization_cache[video_path_str][mode][frame_idx]
 
713
  else:
714
+ # Fallback if frame not in cache
715
  viz_img = Image.fromarray(frame)
716
  else:
717
+ # Mode not cached, return frame
718
  viz_img = Image.fromarray(frame)
719
  else:
720
+ # Not cached, return frame
721
  viz_img = Image.fromarray(frame)
722
 
723
  return frame, viz_img
724
 
725
+
726
def load_example_video(video_file):
    """
    Build a click handler that loads one bundled example video.

    Args:
        video_file: File name of the example inside the examples directory.

    Returns:
        A function taking the visualization ``mode`` and returning whatever
        ``on_upload`` returns for that example.
    """
    def _load_example_video(mode):
        """Load the predefined example video"""
        # Follow the same precedence the UI uses when listing examples:
        # "app/examples" wins when it exists, otherwise fall back to "examples".
        examples_dir = "app/examples" if os.path.exists("app/examples") else "examples"
        example_path = f"{examples_dir}/{video_file}"
        return on_upload(example_path, mode)

    return _load_example_video
733
 
 
 
 
 
 
 
734
  # --- Gradio UI Layout ---
735
+ with gr.Blocks(title="VideoMAE Representation Explorer") as demo:
736
+ gr.Markdown("## 🎥 VideoMAE Frame-by-Frame Representation Explorer")
737
+
738
+ mode_radio = gr.Radio(
739
+ choices=["reconstruction", "attention", "latent"],
740
+ value="reconstruction",
741
+ label="Visualization Mode",
742
+ info="Choose the type of visualization"
743
+ )
 
 
 
 
 
 
 
744
 
745
  with gr.Row():
746
  with gr.Column():
 
752
 
753
  # Event Listeners
754
  video_lists = os.listdir("examples") if os.path.exists("examples") else []
755
+ video_lists = os.listdir("app/examples") if os.path.exists("app/examples") else video_lists
756
  with gr.Row():
757
  video_input = gr.Video(label="Upload Video")
758
  with gr.Column():
759
  for video_file in video_lists:
760
  load_example_btn = gr.Button(f"Load Example Video ({video_file})", variant="secondary")
761
+ load_example_btn.click(load_example_video(video_file), inputs=mode_radio, outputs=[frame_slider, orig_output, viz_output])
 
 
 
 
762
 
763
+ # load_example_btn = gr.Button("Load Example Video (dog.mp4)", variant="secondary")
764
+ video_input.change(on_upload, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
 
 
 
 
 
 
 
 
 
765
 
766
+ frame_slider.change(update_frame, inputs=[frame_slider, mode_radio], outputs=[orig_output, viz_output])
767
def on_mode_change(video_path, mode):
    """
    Handle a change of visualization mode.

    Serves the first frame straight from the cache when the current video
    already has results for ``mode``; otherwise defers to ``on_upload``,
    which computes them (reusing loaded frames when it can).
    """
    if video_path is None:
        return gr.update(maximum=0), None, None

    video_key = str(video_path)

    # Only the currently loaded video can be answered from the cache.
    per_video = visualization_cache.get(video_key) if video_key == current_video_path else None
    if per_video is None or mode not in per_video:
        return on_upload(video_path, mode)

    last_index = len(stored_frames) - 1
    head = stored_frames[0]
    first_frame = np.array(head) if isinstance(head, Image.Image) else head

    cached_frames = per_video[mode]
    # Fall back to the raw frame when frame 0 has no cached visualization.
    viz_img = cached_frames[0] if 0 in cached_frames else Image.fromarray(first_frame)
    return gr.update(maximum=last_index, value=0), first_frame, viz_img
792
 
793
+ mode_radio.change(on_mode_change, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
 
 
 
 
794
 
795
  if __name__ == "__main__":
796
+ # Initialize model at startup
797
+ initialize_model()
798
+ demo.launch()