# VideoMAE Frame-by-Frame Representation Explorer (Gradio app)
# NOTE: removed stray "Spaces: Runtime error" banner text that was pasted
# from the Hugging Face Spaces UI — it was not valid Python.
| import os | |
| # Import spaces before torch to avoid CUDA initialization error | |
| try: | |
| import spaces | |
| except ImportError: | |
| pass | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from pathlib import Path | |
| from sklearn.decomposition import PCA | |
| from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEModel | |
| from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
| import matplotlib | |
| matplotlib.use('Agg') # Use non-interactive backend | |
| import matplotlib.pyplot as plt | |
| import io | |
# Use the GPU when available; models and tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Helper: render a matplotlib figure as an in-memory PIL image.
def fig_to_image(fig):
    """Render a matplotlib figure into an in-memory PNG and return a PIL Image.

    The figure is closed after rendering to free matplotlib resources.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)
def load_video(video_path, num_frames=16, sample_rate=4):
    """Load ``num_frames`` RGB frames from a video file.

    Frames are sampled every ``sample_rate`` frames from the start of the
    clip; if the clip is too short for that, ``num_frames`` indices are
    spread evenly across the whole clip instead.

    Args:
        video_path: Path to the video file.
        num_frames: Number of frames to sample.
        sample_rate: Temporal stride between sampled frames.

    Returns:
        list[PIL.Image.Image]: Sampled frames in RGB order.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If the file cannot be opened, reports no frames, or no
            frame could be decoded.
    """
    video_path = Path(video_path)
    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Guard against broken/empty containers: frame_count <= 0 would make
        # np.linspace(0, -1, ...) produce negative indices below.
        if frame_count <= 0:
            raise ValueError(f"Video reports no frames: {video_path}")
        if sample_rate * num_frames > frame_count:
            frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
            print(f"warning: only {num_frames} frames are sampled from {frame_count} frames")
        else:
            frame_indices = np.arange(0, sample_rate * num_frames, sample_rate)
            print(f"Sampling {frame_indices}")
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes BGR; convert to RGB for PIL/transformers.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
        print(f"Loaded {len(frames)} frames")
    finally:
        # Always release the capture, even if decoding raised (the original
        # leaked the handle on exceptions).
        cap.release()
    if not frames:
        raise ValueError(f"No frames could be decoded from: {video_path}")
    return frames
def load_model(model_name, model_type='pretraining'):
    """Fetch a VideoMAE checkpoint together with its image processor.

    Args:
        model_name: Hugging Face hub identifier of the checkpoint.
        model_type: ``'base'`` loads the bare ``VideoMAEModel`` encoder;
            any other value loads ``VideoMAEForPreTraining``.

    Returns:
        tuple: ``(model, processor)`` with the model already on ``device``.
    """
    processor = VideoMAEImageProcessor.from_pretrained(model_name)
    model_cls = VideoMAEModel if model_type == 'base' else VideoMAEForPreTraining
    model = model_cls.from_pretrained(model_name).to(device)
    return model, processor
# Global model and processor, lazily populated by initialize_model().
model = None
processor = None
def initialize_model(model_name='MCG-NJU/videomae-base'):
    """Lazily load the global model/processor pair (no-op once loaded).

    Returns:
        tuple: the global ``(model, processor)``.
    """
    global model, processor
    if model is not None:
        # Already initialized — reuse the cached pair.
        return model, processor
    print(f"Loading model: {model_name}")
    model, processor = load_model(model_name)
    print("Model loaded successfully")
    return model, processor
def visualize_attention(video_frames, model, processor, layer_idx=-1):
    """Overlay an averaged self-attention map on the first video frame.

    The attention of the chosen encoder layer is averaged over heads and
    over query positions (i.e. attention *received* per patch), reshaped to
    the spatial patch grid, min-max normalized, upsampled, and drawn over
    the first (unnormalized) input frame.

    Args:
        video_frames: Frames accepted by the VideoMAE image processor.
        model: VideoMAE model (pretraining wrapper or bare encoder).
        processor: Matching ``VideoMAEImageProcessor``.
        layer_idx: Encoder layer whose attention is shown; negative values
            index from the end (default: last layer).

    Returns:
        PIL.Image.Image: Rendered overlay, suitable for Gradio.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    # Attention weights live on the encoder; unwrap the pretraining wrapper.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model
    # SDPA does not expose attention probabilities, so temporarily force
    # eager attention. A sentinel distinguishes "attribute absent" from
    # "attribute is None" — the original never reverted the override when
    # the attribute was absent, leaking "eager" into the config.
    _missing = object()
    original_attn_impl = getattr(encoder_model.config, '_attn_implementation', _missing)
    encoder_model.config._attn_implementation = "eager"
    try:
        with torch.no_grad():  # inference only; outputs are detached below
            outputs = encoder_model(pixel_values, output_attentions=True)
    finally:
        if original_attn_impl is _missing:
            try:
                del encoder_model.config._attn_implementation
            except AttributeError:
                pass
        else:
            encoder_model.config._attn_implementation = original_attn_impl
    attentions = outputs.attentions
    if layer_idx < 0:
        layer_idx = len(attentions) + layer_idx
    # (num_heads, seq_len, seq_len) for the first batch element.
    attention_weights = attentions[layer_idx][0]
    avg_attn = attention_weights.mean(dim=0)  # average over heads
    # Undo ImageNet normalization so the frame displays correctly.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)
    seq_len = avg_attn.shape[0]
    H_p = height // patch_size
    W_p = width // patch_size
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Re-derive the temporal dimension when the sequence still tiles
        # the spatial grid exactly.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")
    # Attention received by each key patch, averaged over all queries.
    avg_attn_received = avg_attn.mean(dim=0)
    attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
    # Create visualization for first frame
    frame_idx = 0
    frame_img = frames_unnorm[frame_idx * tubelet_size]
    attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    attn_map_upsampled = cv2.resize(attn_map, (width, height))
    # Create overlay
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(frame_img)
    ax.imshow(attn_map_upsampled, alpha=0.5, cmap='jet')
    ax.set_title(f"Attention Map - Frame {frame_idx * tubelet_size}")
    ax.axis('off')
    return fig_to_image(fig)
def visualize_latent(video_frames, model, processor):
    """
    Visualize latent space representations from VideoMAE model.

    The encoder's last hidden states are projected to 3 components with
    PCA and rendered as an RGB image over the spatial patch grid of the
    first temporal patch. Returns PIL Image for Gradio.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    # Hidden states live on the encoder; unwrap the pretraining wrapper.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model
    outputs = encoder_model(pixel_values, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state[0]  # (seq_len, hidden_size)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    # Unnormalize frames (undo ImageNet mean/std) for display.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)
    seq_len = hidden_states.shape[0]
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Re-derive the temporal dimension when the sequence still tiles
        # the spatial grid exactly.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")
    hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
    hidden_size = hidden_states_reshaped.shape[-1]
    hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()
    # PCA over all patch tokens; first 3 components become RGB channels.
    pca = PCA(n_components=3)
    pca_components = pca.fit_transform(hidden_states_flat)
    pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)
    # Arrange patches on a 2D grid: prefer a square, otherwise the most
    # square-like factor pair of num_patches_per_frame.
    H_p = int(np.sqrt(num_patches_per_frame))
    W_p = H_p
    if H_p * W_p == num_patches_per_frame:
        pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
    else:
        factors = []
        for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
            if num_patches_per_frame % i == 0:
                factors.append((i, num_patches_per_frame // i))
        if factors:
            H_p, W_p = factors[-1]
            pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
        else:
            raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
    # Normalize components per temporal patch and channel into [0, 1].
    for t in range(num_temporal_patches):
        for c in range(3):
            comp = pca_spatial[t, :, :, c]
            comp_min, comp_max = comp.min(), comp.max()
            if comp_max > comp_min:
                pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
            else:
                # Constant component: render as mid-gray.
                pca_spatial[t, :, :, c] = 0.5
    # Show first frame
    frame_idx = 0
    frame_img = frames_unnorm[frame_idx * tubelet_size]  # NOTE(review): computed but never drawn below
    rgb_image = pca_spatial[frame_idx]
    # Nearest-neighbour upscaling keeps patch boundaries crisp.
    upscale_factor = 8
    rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
    fig = plt.figure(figsize=(6,6))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(rgb_image_upscaled)
    ax.set_title(f"PCA Components (RGB = PC1, PC2, PC3)")
    ax.axis('off')
    plt.suptitle(f"Explained Variance: {pca.explained_variance_ratio_.sum():.2%}", fontsize=12)
    plt.tight_layout()
    return fig_to_image(fig)
def compute_reconstruction_all_frames(video_frames, model, processor):
    """
    Compute reconstruction for all frames and return as numpy arrays.

    A random 90% of tubelet patches is masked, the pretraining model
    predicts the masked patch pixels, and the predictions are stitched
    back into the (unnormalized) video; visible patches keep their
    original pixels.

    Returns: (original_frames, reconstructed_frames) as numpy arrays,
    each (time, height, width, channels) with values clipped to [0, 1].
    """
    inputs = processor(video_frames, return_tensors="pt")
    T, C, H, W = inputs['pixel_values'][0].shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    T = T//tubelet_size  # NOTE(review): T, C, H, W are not used below
    # Mask 90% of all tubelet patches, chosen uniformly at random per sample.
    num_patches = (model.config.image_size // model.config.patch_size) ** 2
    num_masked = int(0.9 * num_patches * (model.config.num_frames // model.config.tubelet_size))
    total_patches = (model.config.num_frames // model.config.tubelet_size) * num_patches
    batch_size = inputs['pixel_values'].shape[0]
    bool_masked_pos = torch.zeros((batch_size, total_patches), dtype=torch.bool)
    for b in range(batch_size):
        mask_indices = np.random.choice(total_patches, num_masked, replace=False)
        bool_masked_pos[b, mask_indices] = True
    inputs['bool_masked_pos'] = bool_masked_pos.to(device)
    inputs['pixel_values'] = inputs['pixel_values'].to(device)
    outputs = model(**inputs)
    logits = outputs.logits  # predictions for the masked patches only
    pixel_values = inputs['pixel_values']
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    total_patches = num_temporal_patches * num_patches_per_frame
    # Unnormalize (undo ImageNet mean/std) so patches are in pixel space.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    # Patchify: (B, T', tubelet, C, H/p, p, W/p, p) -> (B, patches, patch_dim).
    frames_patched = frames_unnorm.view(
        batch_size, time // tubelet_size, tubelet_size, num_channels,
        height // patch_size, patch_size, width // patch_size, patch_size,
    )
    frames_patched = frames_patched.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
    videos_patch = frames_patched.view(
        batch_size, total_patches, tubelet_size * patch_size * patch_size * num_channels,
    )
    if model.config.norm_pix_loss:
        # Predictions are in per-patch-normalized space; map them back to
        # pixel space using each patch's own mean/std.
        patch_mean = videos_patch.mean(dim=-2, keepdim=True)
        patch_std = (videos_patch.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        logits_denorm = logits * patch_std + patch_mean
    else:
        logits_denorm = torch.clamp(logits, 0.0, 1.0)
    # Replace masked patches with the model's predictions.
    # NOTE(review): bool_masked_pos is the CPU copy while videos_patch may
    # be on CUDA — confirm cross-device boolean indexing works on GPU runs.
    reconstructed_patches = videos_patch.clone()
    reconstructed_patches[bool_masked_pos] = logits_denorm.reshape(-1, tubelet_size * patch_size * patch_size * num_channels)
    # Un-patchify back to (B, T, C, H, W).
    reconstructed_patches_reshaped = reconstructed_patches.view(
        batch_size, time // tubelet_size, height // patch_size, width // patch_size,
        tubelet_size, patch_size, patch_size, num_channels,
    )
    reconstructed_patches_reshaped = reconstructed_patches_reshaped.permute(0, 1, 4, 7, 2, 5, 3, 6).contiguous()
    reconstructed_frames = reconstructed_patches_reshaped.view(
        batch_size, time, num_channels, height, width,
    )
    original_frames = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    reconstructed_frames_np = reconstructed_frames[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    original_frames = np.clip(original_frames, 0, 1)
    reconstructed_frames_np = np.clip(reconstructed_frames_np, 0, 1)
    return original_frames, reconstructed_frames_np
def visualize_reconstruction(video_frames, model, processor):
    """
    Visualize reconstruction from VideoMAE model.

    Masks 90% of tubelet patches, runs the pretraining model, stitches the
    predicted patches back into pixel space, and plots the first
    reconstructed frame. Returns PIL Image for Gradio.

    NOTE(review): this duplicates compute_reconstruction_all_frames almost
    line-for-line; consider delegating to it.
    """
    inputs = processor(video_frames, return_tensors="pt")
    T, C, H, W = inputs['pixel_values'][0].shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    T = T//tubelet_size  # NOTE(review): T, C, H, W are not used below
    # Mask 90% of all tubelet patches, chosen uniformly at random per sample.
    num_patches = (model.config.image_size // model.config.patch_size) ** 2
    num_masked = int(0.9 * num_patches * (model.config.num_frames // model.config.tubelet_size))
    total_patches = (model.config.num_frames // model.config.tubelet_size) * num_patches
    batch_size = inputs['pixel_values'].shape[0]
    bool_masked_pos = torch.zeros((batch_size, total_patches), dtype=torch.bool)
    for b in range(batch_size):
        mask_indices = np.random.choice(total_patches, num_masked, replace=False)
        bool_masked_pos[b, mask_indices] = True
    inputs['bool_masked_pos'] = bool_masked_pos.to(device)
    inputs['pixel_values'] = inputs['pixel_values'].to(device)
    outputs = model(**inputs)
    logits = outputs.logits  # predictions for the masked patches only
    pixel_values = inputs['pixel_values']
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    total_patches = num_temporal_patches * num_patches_per_frame
    # Unnormalize (undo ImageNet mean/std) so patches are in pixel space.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    # Patchify: (B, T', tubelet, C, H/p, p, W/p, p) -> (B, patches, patch_dim).
    frames_patched = frames_unnorm.view(
        batch_size, time // tubelet_size, tubelet_size, num_channels,
        height // patch_size, patch_size, width // patch_size, patch_size,
    )
    frames_patched = frames_patched.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
    videos_patch = frames_patched.view(
        batch_size, total_patches, tubelet_size * patch_size * patch_size * num_channels,
    )
    if model.config.norm_pix_loss:
        # Predictions are in per-patch-normalized space; map them back to
        # pixel space using each patch's own mean/std.
        patch_mean = videos_patch.mean(dim=-2, keepdim=True)
        patch_std = (videos_patch.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        logits_denorm = logits * patch_std + patch_mean
    else:
        logits_denorm = torch.clamp(logits, 0.0, 1.0)
    # Replace masked patches with the model's predictions.
    # NOTE(review): bool_masked_pos is the CPU copy while videos_patch may
    # be on CUDA — confirm cross-device boolean indexing works on GPU runs.
    reconstructed_patches = videos_patch.clone()
    reconstructed_patches[bool_masked_pos] = logits_denorm.reshape(-1, tubelet_size * patch_size * patch_size * num_channels)
    # Un-patchify back to (B, T, C, H, W).
    reconstructed_patches_reshaped = reconstructed_patches.view(
        batch_size, time // tubelet_size, height // patch_size, width // patch_size,
        tubelet_size, patch_size, patch_size, num_channels,
    )
    reconstructed_patches_reshaped = reconstructed_patches_reshaped.permute(0, 1, 4, 7, 2, 5, 3, 6).contiguous()
    reconstructed_frames = reconstructed_patches_reshaped.view(
        batch_size, time, num_channels, height, width,
    )
    original_frames = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    reconstructed_frames_np = reconstructed_frames[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    original_frames = np.clip(original_frames, 0, 1)  # NOTE(review): not used in the plot below
    reconstructed_frames_np = np.clip(reconstructed_frames_np, 0, 1)
    # Show first frame
    frame_idx = 0
    fig = plt.figure(figsize=(6,6))
    ax = plt.subplot(111)
    ax.imshow(reconstructed_frames_np[frame_idx * tubelet_size])
    ax.set_title(f"Reconstructed Frame: {frame_idx * tubelet_size}")
    ax.axis('off')
    return fig_to_image(fig)
def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
    """
    Compute attention maps for all frames.

    Attention of the chosen encoder layer is averaged over heads and over
    query positions (attention received per patch), reshaped to the
    spatial patch grid, min-max normalized, and upsampled to frame size.

    Args:
        video_frames: frames accepted by the VideoMAE image processor.
        model: VideoMAE model (pretraining wrapper or bare encoder).
        processor: matching VideoMAEImageProcessor.
        layer_idx: encoder layer to visualize; negative counts from the end.

    Returns: (original_frames, attention_maps) as numpy arrays —
    unnormalized frames (T, H, W, C) and one (H, W) map per temporal patch.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    # Attention weights live on the encoder; unwrap the pretraining wrapper.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model
    # Temporarily force eager attention so attention probabilities are
    # returned (SDPA does not expose them).
    # NOTE(review): if the config had no _attn_implementation attribute,
    # getattr returns None and the "eager" override is never reverted in
    # the finally block — the temporary setting leaks into the config.
    original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
    encoder_model.config._attn_implementation = "eager"
    try:
        outputs = encoder_model(pixel_values, output_attentions=True)
    finally:
        if original_attn_impl is not None:
            encoder_model.config._attn_implementation = original_attn_impl
    attentions = outputs.attentions
    if layer_idx < 0:
        layer_idx = len(attentions) + layer_idx
    # (num_heads, seq_len, seq_len) for the first batch element.
    attention_weights = attentions[layer_idx][0]
    avg_attn = attention_weights.mean(dim=0)  # average over heads
    # Unnormalize frames (undo ImageNet mean/std) for display.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)
    seq_len = avg_attn.shape[0]
    H_p = height // patch_size
    W_p = width // patch_size
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Re-derive the temporal dimension when the sequence still tiles
        # the spatial grid exactly.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")
    # Attention received by each key patch, averaged over all queries.
    avg_attn_received = avg_attn.mean(dim=0)
    attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
    # Create attention maps for all temporal patches
    attention_maps = []
    for frame_idx in range(num_temporal_patches):
        attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
        attn_map_upsampled = cv2.resize(attn_map, (width, height))
        attention_maps.append(attn_map_upsampled)
    return frames_unnorm, attention_maps
def compute_latent_all_frames(video_frames, model, processor):
    """
    Compute PCA latent visualizations for all frames.

    Encoder last hidden states are projected to 3 PCA components, arranged
    on the spatial patch grid, normalized per temporal patch/channel, and
    upscaled with nearest-neighbour interpolation.

    Returns: (original_frames, pca_images) as numpy arrays — unnormalized
    frames (T, H, W, C) and one RGB image per temporal patch.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    # Hidden states live on the encoder; unwrap the pretraining wrapper.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model
    outputs = encoder_model(pixel_values, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state[0]  # (seq_len, hidden_size)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    # Unnormalize frames (undo ImageNet mean/std) for display.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)
    seq_len = hidden_states.shape[0]
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Re-derive the temporal dimension when the sequence still tiles
        # the spatial grid exactly.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")
    hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
    hidden_size = hidden_states_reshaped.shape[-1]
    hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()
    # PCA over all patch tokens; first 3 components become RGB channels.
    pca = PCA(n_components=3)
    pca_components = pca.fit_transform(hidden_states_flat)
    pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)
    # Arrange patches on a 2D grid: prefer a square, otherwise the most
    # square-like factor pair of num_patches_per_frame.
    H_p = int(np.sqrt(num_patches_per_frame))
    W_p = H_p
    if H_p * W_p == num_patches_per_frame:
        pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
    else:
        factors = []
        for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
            if num_patches_per_frame % i == 0:
                factors.append((i, num_patches_per_frame // i))
        if factors:
            H_p, W_p = factors[-1]
            pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
        else:
            raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
    # Normalize components per temporal patch and channel into [0, 1].
    for t in range(num_temporal_patches):
        for c in range(3):
            comp = pca_spatial[t, :, :, c]
            comp_min, comp_max = comp.min(), comp.max()
            if comp_max > comp_min:
                pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
            else:
                # Constant component: render as mid-gray.
                pca_spatial[t, :, :, c] = 0.5
    # Create upscaled images for all frames
    # Nearest-neighbour upscaling keeps patch boundaries crisp.
    upscale_factor = 8
    pca_images = []
    for t_idx in range(num_temporal_patches):
        rgb_image = pca_spatial[t_idx]
        rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
        pca_images.append(rgb_image_upscaled)
    return frames_unnorm, pca_images
# Dummy function kept for backward compatibility; not used by the cached UI pipeline.
def process_video(video_path):
    """Read every frame of a video and pair it with a JET-colormapped copy.

    Returns:
        tuple: ``(frames, visualizations)`` — RGB frames and their
        half-intensity JET colormap renderings.
    """
    capture = cv2.VideoCapture(video_path)
    frames = []
    while capture.isOpened():
        ok, bgr_frame = capture.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    capture.release()
    visualizations = [
        cv2.applyColorMap((rgb * 0.5).astype(np.uint8), cv2.COLORMAP_JET)
        for rgb in frames
    ]
    return frames, visualizations
# Global state to store frames after upload
stored_frames = []  # PIL frames of the most recently loaded video
stored_viz = []  # NOTE(review): declared global in on_upload but never written in this file
# Cache for visualization results: {video_path: {mode: {frame_idx: image}}}
visualization_cache = {}
current_video_path = None  # str path of the video whose frames are in stored_frames
def on_upload(video_path, mode):
    """Handle a video upload: load frames (if new), compute and cache the
    per-frame visualizations for ``mode``, and return the initial UI state.

    Args:
        video_path: path of the uploaded video, or None when cleared.
        mode: one of "reconstruction", "attention", "latent".

    Returns:
        tuple: (gr.update for the frame slider, first original frame as a
        numpy array, visualization image for frame 0).
    """
    global stored_frames, stored_viz, model, processor, visualization_cache, current_video_path
    if video_path is None:
        # Upload cleared — reset the slider and both image panes.
        return gr.update(maximum=0), None, None
    # Initialize model if needed
    if model is None:
        model, processor = initialize_model()
    # Check if we need to recompute (new video or mode not cached)
    video_path_str = str(video_path)
    need_to_load_video = (video_path_str != current_video_path)
    need_to_compute_mode = (video_path_str not in visualization_cache or mode not in visualization_cache[video_path_str])
    if need_to_load_video:
        # Load video frames
        print(f"Loading video: {video_path_str}")
        video_frames = load_video(video_path)
        stored_frames = video_frames
        current_video_path = video_path_str
    else:
        # Reuse already loaded frames
        video_frames = stored_frames
    # Initialize cache for this video
    if video_path_str not in visualization_cache:
        visualization_cache[video_path_str] = {}
    if need_to_compute_mode:
        # Compute all visualizations and cache them
        print(f"Computing {mode} visualization for all frames...")
        num_frames = len(stored_frames)
        tubelet_size = model.config.tubelet_size
        if mode == "reconstruction":
            original_frames, reconstructed_frames = compute_reconstruction_all_frames(video_frames, model, processor)
            # Cache as images per frame - map model frames to stored frames
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                # Map stored frame index to model frame index
                model_frame_idx = min(i, len(reconstructed_frames) - 1)
                fig = plt.figure(figsize=(6, 6))
                ax = plt.subplot(111)
                ax.imshow(reconstructed_frames[model_frame_idx])
                ax.set_title(f"Reconstructed Frame: {i}")
                ax.axis('off')
                visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
        elif mode == "attention":
            original_frames, attention_maps = compute_attention_all_frames(video_frames, model, processor)
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                # Map stored frame to temporal patch
                temporal_patch_idx = min(i // tubelet_size, len(attention_maps) - 1)
                model_frame_idx = min(i, len(original_frames) - 1)
                if temporal_patch_idx < len(attention_maps):
                    fig = plt.figure(figsize=(6, 6))
                    ax = plt.subplot(111)
                    ax.imshow(original_frames[model_frame_idx])
                    ax.imshow(attention_maps[temporal_patch_idx], alpha=0.5, cmap='jet')
                    ax.set_title(f"Attention Map - Frame {i}")
                    ax.axis('off')
                    visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
        elif mode == "latent":
            original_frames, pca_images = compute_latent_all_frames(video_frames, model, processor)
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                # Map stored frame to temporal patch
                temporal_patch_idx = min(i // tubelet_size, len(pca_images) - 1)
                if temporal_patch_idx < len(pca_images):
                    fig = plt.figure(figsize=(6, 6))
                    ax = plt.subplot(111)
                    ax.imshow(pca_images[temporal_patch_idx])
                    ax.set_title(f"PCA Components - Frame {i}")
                    ax.axis('off')
                    visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
        print(f"Caching complete for {mode} mode")
    # Load from cache
    max_idx = len(stored_frames) - 1
    frame_idx = 0
    # Get original frame
    if isinstance(stored_frames[0], Image.Image):
        first_frame = np.array(stored_frames[0])
    else:
        first_frame = stored_frames[0]
    # Get visualization from cache
    if video_path_str in visualization_cache and mode in visualization_cache[video_path_str]:
        if frame_idx in visualization_cache[video_path_str][mode]:
            viz_img = visualization_cache[video_path_str][mode][frame_idx]
        else:
            # Fallback if frame not in cache
            viz_img = Image.fromarray(first_frame)
    else:
        # Fallback if not cached
        viz_img = Image.fromarray(first_frame)
    return gr.update(maximum=max_idx, value=0), first_frame, viz_img
def update_frame(idx, mode):
    """Return the original frame and its cached visualization for slider
    index ``idx``.

    Falls back to showing the raw frame when no visualization has been
    cached for the current video/mode/frame combination.
    """
    global stored_frames, visualization_cache, current_video_path
    if not stored_frames:
        return None, None
    # Clamp the slider index to the available frames.
    frame_idx = min(int(idx), len(stored_frames) - 1)
    raw_frame = stored_frames[frame_idx]
    frame = np.array(raw_frame) if isinstance(raw_frame, Image.Image) else raw_frame
    # Cache lookup is a pure dict-chain walk — no recomputation here.
    if current_video_path:
        mode_cache = visualization_cache.get(current_video_path, {}).get(mode, {})
    else:
        mode_cache = {}
    viz_img = mode_cache.get(frame_idx)
    if viz_img is None:
        # Nothing cached for this frame/mode — show the raw frame instead.
        viz_img = Image.fromarray(frame)
    return frame, viz_img
def load_example_video(video_file):
    """Build a click handler that loads ``examples/<video_file>`` in the
    currently selected visualization mode."""
    def _handler(mode):
        """Delegate to on_upload with the fixed example path."""
        return on_upload(f"examples/{video_file}", mode)
    return _handler
# --- Gradio UI Layout ---
with gr.Blocks(title="VideoMAE Representation Explorer") as demo:
    gr.Markdown("## 🎥 VideoMAE Frame-by-Frame Representation Explorer")
    mode_radio = gr.Radio(
        choices=["reconstruction", "attention", "latent"],
        value="reconstruction",
        label="Visualization Mode",
        info="Choose the type of visualization"
    )
    with gr.Row():
        with gr.Column():
            orig_output = gr.Image(label="Original Frame")
        with gr.Column():
            viz_output = gr.Image(label="Representation / Attention")
    # Slider maximum is reset by on_upload once a video is loaded.
    frame_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Frame Index")
    # Event Listeners
    # Prefer app/examples when present; otherwise fall back to examples/.
    video_lists = os.listdir("examples") if os.path.exists("examples") else []
    video_lists = os.listdir("app/examples") if os.path.exists("app/examples") else video_lists
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
        with gr.Column():
            # One "load example" button per discovered example file.
            for video_file in video_lists:
                load_example_btn = gr.Button(f"Load Example Video ({video_file})", variant="secondary")
                load_example_btn.click(load_example_video(video_file), inputs=mode_radio, outputs=[frame_slider, orig_output, viz_output])
    video_input.change(on_upload, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
    frame_slider.change(update_frame, inputs=[frame_slider, mode_radio], outputs=[orig_output, viz_output])
    def on_mode_change(video_path, mode):
        """Handle mode change - serve from cache when possible, otherwise
        delegate to on_upload (which computes and caches)."""
        global stored_frames, model, processor, visualization_cache, current_video_path
        if video_path is None:
            return gr.update(maximum=0), None, None
        video_path_str = str(video_path)
        # If video is already loaded and mode is cached, just return cached result
        if video_path_str == current_video_path and video_path_str in visualization_cache:
            if mode in visualization_cache[video_path_str]:
                max_idx = len(stored_frames) - 1
                frame_idx = 0
                if isinstance(stored_frames[0], Image.Image):
                    first_frame = np.array(stored_frames[0])
                else:
                    first_frame = stored_frames[0]
                if frame_idx in visualization_cache[video_path_str][mode]:
                    viz_img = visualization_cache[video_path_str][mode][frame_idx]
                else:
                    viz_img = Image.fromarray(first_frame)
                return gr.update(maximum=max_idx, value=0), first_frame, viz_img
        # Otherwise, compute (will use cached video frames if available)
        return on_upload(video_path, mode)
    mode_radio.change(on_mode_change, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
if __name__ == "__main__":
    # Initialize model at startup, before launching the UI.
    initialize_model()
    demo.launch()