erow committed
Commit dc22224 · 1 Parent(s): 38e3a8d

Files changed (2):
  1. README.md +8 -1
  2. app.py +792 -0
README.md CHANGED
@@ -10,4 +10,11 @@ app_file: app.py
  pinned: false
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ This is a demo of the VideoMAE model that visualizes the attention maps, latent space, and masked reconstruction of a video.
+
+ Choose one of the following visualization modes:
+ - Reconstruction: mask 90% of the video patches and reconstruct them from the visible ones.
+ - Attention: visualize the average attention map of the last layer.
+ - Latent: visualize the first three PCA components of the video's latent space.
+
+ You can choose the model and either load an example video or upload your own.
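
For reference, here is a minimal sketch of what the reconstruction mode does under the hood. The model name (MCG-NJU/videomae-base) and the 90% mask ratio match the defaults used in app.py; the random frames below are stand-ins for a real clip:

```python
import numpy as np
import torch
from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining

processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-base")
model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")

# 16 random RGB frames stand in for a real video clip
video = list(np.random.randint(0, 256, (16, 224, 224, 3), dtype=np.uint8))
inputs = processor(video, return_tensors="pt")

# Mask 90% of the tubelet patches at random, as the demo does
num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
seq_len = (model.config.num_frames // model.config.tubelet_size) * num_patches_per_frame
bool_masked_pos = torch.zeros(1, seq_len, dtype=torch.bool)
masked = np.random.choice(seq_len, int(0.9 * seq_len), replace=False)
bool_masked_pos[0, masked] = True

outputs = model(**inputs, bool_masked_pos=bool_masked_pos)
print(outputs.loss)          # reconstruction loss on the masked patches
print(outputs.logits.shape)  # reconstructed pixel patches for the masked positions
```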
app.py ADDED
@@ -0,0 +1,792 @@
import os
import gradio as gr
import cv2
import numpy as np
import torch
from PIL import Image
from pathlib import Path
from sklearn.decomposition import PCA
from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEModel
from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import io

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Helper function to convert a matplotlib figure to a PIL Image
def fig_to_image(fig):
    """Convert a matplotlib figure to a PIL Image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img

def load_video(video_path, num_frames=16, sample_rate=4):
    """
    Load a video from a file path.
    Returns a list of PIL Images.
    """
    video_path = Path(video_path)

    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    frames = []

    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if sample_rate * num_frames > frame_count:
        # Video is too short for the requested sampling; spread frames evenly instead
        frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
        print(f"Warning: video has only {frame_count} frames; sampling {num_frames} evenly spaced frames")
    else:
        frame_indices = np.arange(0, sample_rate * num_frames, sample_rate)

    print(f"Sampling frame indices: {frame_indices}")
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # Convert BGR to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    print(f"Loaded {len(frames)} frames")
    cap.release()
    return frames

def load_model(model_name, model_type='pretraining'):
    """
    Load model and processor by name.
    model_type: 'pretraining' for VideoMAEForPreTraining, 'base' for VideoMAEModel
    """
    processor = VideoMAEImageProcessor.from_pretrained(model_name)
    if model_type == 'base':
        model = VideoMAEModel.from_pretrained(model_name)
    else:
        model = VideoMAEForPreTraining.from_pretrained(model_name)
    model = model.to(device)
    return model, processor

# Global model and processor
model = None
processor = None

def initialize_model(model_name='MCG-NJU/videomae-base'):
    """Initialize the model (called once at startup)."""
    global model, processor
    if model is None:
        print(f"Loading model: {model_name}")
        model, processor = load_model(model_name)
        print("Model loaded successfully")
    return model, processor

def visualize_attention(video_frames, model, processor, layer_idx=-1):
    """
    Visualize attention maps from the VideoMAE model.
    Returns a PIL Image for Gradio.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size

    # Use the underlying VideoMAEModel to get attention weights
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model

    # Disable SDPA and use eager attention so attention weights are returned
    original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
    encoder_model.config._attn_implementation = "eager"

    try:
        outputs = encoder_model(pixel_values, output_attentions=True)
    finally:
        if original_attn_impl is not None:
            encoder_model.config._attn_implementation = original_attn_impl

    attentions = outputs.attentions
    if layer_idx < 0:
        layer_idx = len(attentions) + layer_idx

    attention_weights = attentions[layer_idx][0]
    avg_attn = attention_weights.mean(dim=0)  # average over heads

    # Unnormalize frames
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)

    seq_len = avg_attn.shape[0]
    H_p = height // patch_size
    W_p = width // patch_size
    expected_seq_len = num_temporal_patches * num_patches_per_frame

    if seq_len != expected_seq_len:
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")

    avg_attn_received = avg_attn.mean(dim=0)  # average attention received per patch
    attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)

    # Create visualization for the first frame
    frame_idx = 0
    frame_img = frames_unnorm[frame_idx * tubelet_size]
    attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    attn_map_upsampled = cv2.resize(attn_map, (width, height))

    # Create overlay
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(frame_img)
    ax.imshow(attn_map_upsampled, alpha=0.5, cmap='jet')
    ax.set_title(f"Attention Map - Frame {frame_idx * tubelet_size}")
    ax.axis('off')

    return fig_to_image(fig)

def visualize_latent(video_frames, model, processor):
    """
    Visualize latent space representations from the VideoMAE model.
    Returns a PIL Image for Gradio.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model

    outputs = encoder_model(pixel_values, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state[0]

    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size

    # Unnormalize frames
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)

    seq_len = hidden_states.shape[0]
    expected_seq_len = num_temporal_patches * num_patches_per_frame

    if seq_len != expected_seq_len:
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")

    hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
    hidden_size = hidden_states_reshaped.shape[-1]
    hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()

    # Project patch embeddings onto their first three principal components
    pca = PCA(n_components=3)
    pca_components = pca.fit_transform(hidden_states_flat)
    pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)

    H_p = int(np.sqrt(num_patches_per_frame))
    W_p = H_p

    if H_p * W_p == num_patches_per_frame:
        pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
    else:
        # Fall back to the most square factorization of the patch count
        factors = []
        for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
            if num_patches_per_frame % i == 0:
                factors.append((i, num_patches_per_frame // i))
        if factors:
            H_p, W_p = factors[-1]
            pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
        else:
            raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")

    # Normalize each component to [0, 1] per temporal patch
    for t in range(num_temporal_patches):
        for c in range(3):
            comp = pca_spatial[t, :, :, c]
            comp_min, comp_max = comp.min(), comp.max()
            if comp_max > comp_min:
                pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
            else:
                pca_spatial[t, :, :, c] = 0.5

    # Show first frame
    frame_idx = 0
    rgb_image = pca_spatial[frame_idx]
    upscale_factor = 8
    rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)

    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(rgb_image_upscaled)
    ax.set_title("PCA Components (RGB = PC1, PC2, PC3)")
    ax.axis('off')
    plt.suptitle(f"Explained Variance: {pca.explained_variance_ratio_.sum():.2%}", fontsize=12)
    plt.tight_layout()

    return fig_to_image(fig)

def compute_reconstruction_all_frames(video_frames, model, processor):
    """
    Compute the reconstruction for all frames.
    Returns: (original_frames, reconstructed_frames) as numpy arrays
    """
    inputs = processor(video_frames, return_tensors="pt")

    # Mask 90% of the tubelet patches at random, as in VideoMAE pre-training
    num_patches = (model.config.image_size // model.config.patch_size) ** 2
    num_masked = int(0.9 * num_patches * (model.config.num_frames // model.config.tubelet_size))
    total_patches = (model.config.num_frames // model.config.tubelet_size) * num_patches
    batch_size = inputs['pixel_values'].shape[0]
    bool_masked_pos = torch.zeros((batch_size, total_patches), dtype=torch.bool)

    for b in range(batch_size):
        mask_indices = np.random.choice(total_patches, num_masked, replace=False)
        bool_masked_pos[b, mask_indices] = True

    # Keep the mask on the same device as the pixel values so it can also be
    # used to index tensors on that device below
    bool_masked_pos = bool_masked_pos.to(device)
    inputs['bool_masked_pos'] = bool_masked_pos
    inputs['pixel_values'] = inputs['pixel_values'].to(device)

    outputs = model(**inputs)
    logits = outputs.logits

    pixel_values = inputs['pixel_values']
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    total_patches = num_temporal_patches * num_patches_per_frame

    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean

    # Patchify: (B, T, C, H, W) -> (B, total_patches, tubelet*patch*patch*C)
    frames_patched = frames_unnorm.view(
        batch_size, time // tubelet_size, tubelet_size, num_channels,
        height // patch_size, patch_size, width // patch_size, patch_size,
    )
    frames_patched = frames_patched.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
    videos_patch = frames_patched.view(
        batch_size, total_patches, tubelet_size * patch_size * patch_size * num_channels,
    )

    if model.config.norm_pix_loss:
        # Undo the per-patch normalization used by the norm_pix_loss target
        patch_mean = videos_patch.mean(dim=-2, keepdim=True)
        patch_std = (videos_patch.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        logits_denorm = logits * patch_std + patch_mean
    else:
        logits_denorm = torch.clamp(logits, 0.0, 1.0)

    # Paste the reconstructed patches into the masked positions
    reconstructed_patches = videos_patch.clone()
    reconstructed_patches[bool_masked_pos] = logits_denorm.reshape(-1, tubelet_size * patch_size * patch_size * num_channels)

    # Unpatchify back to (B, T, C, H, W)
    reconstructed_patches_reshaped = reconstructed_patches.view(
        batch_size, time // tubelet_size, height // patch_size, width // patch_size,
        tubelet_size, patch_size, patch_size, num_channels,
    )
    reconstructed_patches_reshaped = reconstructed_patches_reshaped.permute(0, 1, 4, 7, 2, 5, 3, 6).contiguous()
    reconstructed_frames = reconstructed_patches_reshaped.view(
        batch_size, time, num_channels, height, width,
    )

    original_frames = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    reconstructed_frames_np = reconstructed_frames[0].permute(0, 2, 3, 1).detach().cpu().numpy()

    original_frames = np.clip(original_frames, 0, 1)
    reconstructed_frames_np = np.clip(reconstructed_frames_np, 0, 1)

    return original_frames, reconstructed_frames_np

def visualize_reconstruction(video_frames, model, processor):
    """
    Visualize the reconstruction from the VideoMAE model.
    Returns a PIL Image for Gradio.
    """
    # Same pipeline as compute_reconstruction_all_frames, rendered for the first frame only
    _, reconstructed_frames_np = compute_reconstruction_all_frames(video_frames, model, processor)
    tubelet_size = model.config.tubelet_size

    # Show first frame
    frame_idx = 0
    fig = plt.figure(figsize=(6, 6))
    ax = plt.subplot(111)

    ax.imshow(reconstructed_frames_np[frame_idx * tubelet_size])
    ax.set_title(f"Reconstructed Frame: {frame_idx * tubelet_size}")
    ax.axis('off')

    return fig_to_image(fig)

def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
    """
    Compute attention maps for all frames.
    Returns: (original_frames, attention_maps) as numpy arrays
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size

    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model

    # Disable SDPA and use eager attention so attention weights are returned
    original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
    encoder_model.config._attn_implementation = "eager"

    try:
        outputs = encoder_model(pixel_values, output_attentions=True)
    finally:
        if original_attn_impl is not None:
            encoder_model.config._attn_implementation = original_attn_impl

    attentions = outputs.attentions
    if layer_idx < 0:
        layer_idx = len(attentions) + layer_idx

    attention_weights = attentions[layer_idx][0]
    avg_attn = attention_weights.mean(dim=0)

    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)

    seq_len = avg_attn.shape[0]
    H_p = height // patch_size
    W_p = width // patch_size
    expected_seq_len = num_temporal_patches * num_patches_per_frame

    if seq_len != expected_seq_len:
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")

    avg_attn_received = avg_attn.mean(dim=0)
    attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)

    # Create attention maps for all temporal patches
    attention_maps = []
    for frame_idx in range(num_temporal_patches):
        attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
        attn_map_upsampled = cv2.resize(attn_map, (width, height))
        attention_maps.append(attn_map_upsampled)

    return frames_unnorm, attention_maps

def compute_latent_all_frames(video_frames, model, processor):
    """
    Compute PCA latent visualizations for all frames.
    Returns: (original_frames, pca_images) as numpy arrays
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model

    outputs = encoder_model(pixel_values, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state[0]

    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size

    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)

    seq_len = hidden_states.shape[0]
    expected_seq_len = num_temporal_patches * num_patches_per_frame

    if seq_len != expected_seq_len:
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")

    hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
    hidden_size = hidden_states_reshaped.shape[-1]
    hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()

    pca = PCA(n_components=3)
    pca_components = pca.fit_transform(hidden_states_flat)
    pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)

    H_p = int(np.sqrt(num_patches_per_frame))
    W_p = H_p

    if H_p * W_p == num_patches_per_frame:
        pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
    else:
        # Fall back to the most square factorization of the patch count
        factors = []
        for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
            if num_patches_per_frame % i == 0:
                factors.append((i, num_patches_per_frame // i))
        if factors:
            H_p, W_p = factors[-1]
            pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
        else:
            raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")

    # Normalize each component to [0, 1] per temporal patch
    for t in range(num_temporal_patches):
        for c in range(3):
            comp = pca_spatial[t, :, :, c]
            comp_min, comp_max = comp.min(), comp.max()
            if comp_max > comp_min:
                pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
            else:
                pca_spatial[t, :, :, c] = 0.5

    # Create upscaled images for all frames
    upscale_factor = 8
    pca_images = []
    for t_idx in range(num_temporal_patches):
        rgb_image = pca_spatial[t_idx]
        rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
        pca_images.append(rgb_image_upscaled)

    return frames_unnorm, pca_images

# Dummy function for backward compatibility
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
    cap.release()
    visualizations = [cv2.applyColorMap((f * 0.5).astype(np.uint8), cv2.COLORMAP_JET) for f in frames]
    return frames, visualizations

# Global state to store frames after upload
stored_frames = []
stored_viz = []

# Cache for visualization results: {video_path: {mode: {frame_idx: image}}}
visualization_cache = {}
current_video_path = None

def on_upload(video_path, mode):
    global stored_frames, stored_viz, model, processor, visualization_cache, current_video_path
    if video_path is None:
        return gr.update(maximum=0), None, None

    # Initialize model if needed
    if model is None:
        model, processor = initialize_model()

    # Check if we need to recompute (new video or mode not cached)
    video_path_str = str(video_path)
    need_to_load_video = (video_path_str != current_video_path)
    need_to_compute_mode = (video_path_str not in visualization_cache or mode not in visualization_cache[video_path_str])

    if need_to_load_video:
        # Load video frames
        print(f"Loading video: {video_path_str}")
        video_frames = load_video(video_path)
        stored_frames = video_frames
        current_video_path = video_path_str
    else:
        # Reuse already loaded frames
        video_frames = stored_frames

    # Initialize cache for this video
    if video_path_str not in visualization_cache:
        visualization_cache[video_path_str] = {}

    if need_to_compute_mode:
        # Compute all visualizations and cache them
        print(f"Computing {mode} visualization for all frames...")
        num_frames = len(stored_frames)
        tubelet_size = model.config.tubelet_size

        if mode == "reconstruction":
            original_frames, reconstructed_frames = compute_reconstruction_all_frames(video_frames, model, processor)
            # Cache one image per frame - map stored frame indices to model frame indices
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                model_frame_idx = min(i, len(reconstructed_frames) - 1)
                fig = plt.figure(figsize=(6, 6))
                ax = plt.subplot(111)
                ax.imshow(reconstructed_frames[model_frame_idx])
                ax.set_title(f"Reconstructed Frame: {i}")
                ax.axis('off')
                visualization_cache[video_path_str][mode][i] = fig_to_image(fig)

        elif mode == "attention":
            original_frames, attention_maps = compute_attention_all_frames(video_frames, model, processor)
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                # Map stored frame to temporal patch
                temporal_patch_idx = min(i // tubelet_size, len(attention_maps) - 1)
                model_frame_idx = min(i, len(original_frames) - 1)
                if temporal_patch_idx < len(attention_maps):
                    fig = plt.figure(figsize=(6, 6))
                    ax = plt.subplot(111)
                    ax.imshow(original_frames[model_frame_idx])
                    ax.imshow(attention_maps[temporal_patch_idx], alpha=0.5, cmap='jet')
                    ax.set_title(f"Attention Map - Frame {i}")
                    ax.axis('off')
                    visualization_cache[video_path_str][mode][i] = fig_to_image(fig)

        elif mode == "latent":
            original_frames, pca_images = compute_latent_all_frames(video_frames, model, processor)
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                # Map stored frame to temporal patch
                temporal_patch_idx = min(i // tubelet_size, len(pca_images) - 1)
                if temporal_patch_idx < len(pca_images):
                    fig = plt.figure(figsize=(6, 6))
                    ax = plt.subplot(111)
                    ax.imshow(pca_images[temporal_patch_idx])
                    ax.set_title(f"PCA Components - Frame {i}")
                    ax.axis('off')
                    visualization_cache[video_path_str][mode][i] = fig_to_image(fig)

        print(f"Caching complete for {mode} mode")

    # Load from cache
    max_idx = len(stored_frames) - 1
    frame_idx = 0

    # Get original frame
    if isinstance(stored_frames[0], Image.Image):
        first_frame = np.array(stored_frames[0])
    else:
        first_frame = stored_frames[0]

    # Get visualization from cache
    if video_path_str in visualization_cache and mode in visualization_cache[video_path_str]:
        if frame_idx in visualization_cache[video_path_str][mode]:
            viz_img = visualization_cache[video_path_str][mode][frame_idx]
        else:
            # Fallback if frame not in cache
            viz_img = Image.fromarray(first_frame)
    else:
        # Fallback if not cached
        viz_img = Image.fromarray(first_frame)

    return gr.update(maximum=max_idx, value=0), first_frame, viz_img

def update_frame(idx, mode):
    global stored_frames, visualization_cache, current_video_path
    if not stored_frames:
        return None, None

    frame_idx = int(idx)
    if frame_idx >= len(stored_frames):
        frame_idx = len(stored_frames) - 1

    # Get frame
    if isinstance(stored_frames[frame_idx], Image.Image):
        frame = np.array(stored_frames[frame_idx])
    else:
        frame = stored_frames[frame_idx]

    # Load visualization from cache (fast)
    video_path_str = current_video_path
    if video_path_str and video_path_str in visualization_cache:
        if mode in visualization_cache[video_path_str]:
            if frame_idx in visualization_cache[video_path_str][mode]:
                viz_img = visualization_cache[video_path_str][mode][frame_idx]
            else:
                # Fallback if frame not in cache
                viz_img = Image.fromarray(frame)
        else:
            # Mode not cached, return the raw frame
            viz_img = Image.fromarray(frame)
    else:
        # Not cached, return the raw frame
        viz_img = Image.fromarray(frame)

    return frame, viz_img


def load_example_video(video_file):
    def _load_example_video(mode):
        """Load the predefined example video."""
        example_path = f"examples/{video_file}"
        return on_upload(example_path, mode)

    return _load_example_video

# --- Gradio UI Layout ---
with gr.Blocks(title="VideoMAE Representation Explorer") as demo:
    gr.Markdown("## 🎥 VideoMAE Frame-by-Frame Representation Explorer")

    mode_radio = gr.Radio(
        choices=["reconstruction", "attention", "latent"],
        value="reconstruction",
        label="Visualization Mode",
        info="Choose the type of visualization"
    )

    with gr.Row():
        with gr.Column():
            orig_output = gr.Image(label="Original Frame")
        with gr.Column():
            viz_output = gr.Image(label="Representation / Attention")

    frame_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Frame Index")

    # Event listeners
    video_lists = os.listdir("examples")
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
        with gr.Column():
            for video_file in video_lists:
                load_example_btn = gr.Button(f"Load Example Video ({video_file})", variant="secondary")
                load_example_btn.click(load_example_video(video_file), inputs=mode_radio, outputs=[frame_slider, orig_output, viz_output])

    video_input.change(on_upload, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
    frame_slider.change(update_frame, inputs=[frame_slider, mode_radio], outputs=[orig_output, viz_output])

    def on_mode_change(video_path, mode):
        """Handle a mode change - compute if not cached, otherwise use the cache."""
        global stored_frames, model, processor, visualization_cache, current_video_path
        if video_path is None:
            return gr.update(maximum=0), None, None

        video_path_str = str(video_path)

        # If the video is already loaded and the mode is cached, just return the cached result
        if video_path_str == current_video_path and video_path_str in visualization_cache:
            if mode in visualization_cache[video_path_str]:
                max_idx = len(stored_frames) - 1
                frame_idx = 0
                if isinstance(stored_frames[0], Image.Image):
                    first_frame = np.array(stored_frames[0])
                else:
                    first_frame = stored_frames[0]
                if frame_idx in visualization_cache[video_path_str][mode]:
                    viz_img = visualization_cache[video_path_str][mode][frame_idx]
                else:
                    viz_img = Image.fromarray(first_frame)
                return gr.update(maximum=max_idx, value=0), first_frame, viz_img

        # Otherwise, compute (reuses cached video frames if available)
        return on_upload(video_path, mode)

    mode_radio.change(on_mode_change, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])

if __name__ == "__main__":
    # Initialize model at startup
    initialize_model()
    demo.launch()