import os

# Import spaces before torch to avoid CUDA initialization error
try:
    import spaces
except ImportError:
    pass

import io
from pathlib import Path

import gradio as gr
import cv2
import numpy as np
import torch
from PIL import Image
from sklearn.decomposition import PCA
from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEModel
from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import matplotlib

matplotlib.use('Agg')  # Use non-interactive backend (headless server)
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def fig_to_image(fig):
    """Render a matplotlib figure to a PIL Image and close the figure.

    Closing the figure is important: these are created per slider frame and
    would otherwise accumulate in pyplot's figure registry.

    Args:
        fig: matplotlib Figure.

    Returns:
        PIL.Image.Image with the rendered contents.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img


def load_video(video_path, num_frames=16, sample_rate=4):
    """Load ``num_frames`` RGB frames from a video file.

    When the clip has at least ``num_frames * sample_rate`` frames, frames are
    taken every ``sample_rate`` frames from the start; otherwise the indices
    are spread evenly over the whole clip.

    Args:
        video_path: path to the video file.
        num_frames: number of frames to sample.
        sample_rate: temporal stride between sampled frames.

    Returns:
        List of PIL Images in RGB order.

    Raises:
        FileNotFoundError: if the path does not exist.
        ValueError: if OpenCV cannot open the file.
    """
    video_path = Path(video_path)
    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if sample_rate * num_frames > frame_count:
            frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
            print(f"warning: only {num_frames} frames are sampled from {frame_count} frames")
        else:
            frame_indices = np.arange(0, sample_rate * num_frames, sample_rate)
            print(f"Sampling {frame_indices}")
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; the model expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
        print(f"Loaded {len(frames)} frames")
    finally:
        cap.release()
    return frames


def load_model(model_name, model_type='pretraining'):
    """Load a VideoMAE model and its image processor.

    Args:
        model_name: Hugging Face model id.
        model_type: 'pretraining' for VideoMAEForPreTraining,
            'base' for VideoMAEModel (encoder only).

    Returns:
        (model, processor) with the model moved to the global ``device``.
    """
    processor = VideoMAEImageProcessor.from_pretrained(model_name)
    if model_type == 'base':
        model = VideoMAEModel.from_pretrained(model_name)
    else:
        model = VideoMAEForPreTraining.from_pretrained(model_name)
    model = model.to(device)
    return model, processor


# Global model and processor (lazily created by initialize_model)
model = None
processor = None


def initialize_model(model_name='MCG-NJU/videomae-base'):
    """Load the global model/processor once; subsequent calls are no-ops."""
    global model, processor
    if model is None:
        print(f"Loading model: {model_name}")
        model, processor = load_model(model_name)
        print("Model loaded successfully")
    return model, processor


def _get_encoder(model):
    """Return the bare VideoMAE encoder (unwraps VideoMAEForPreTraining)."""
    return model.videomae if hasattr(model, 'videomae') else model


def _unnormalize(pixel_values):
    """Undo ImageNet normalization on a (B, T, C, H, W) tensor."""
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    return pixel_values * std + mean


def _to_numpy_frames(frames_unnorm):
    """Convert sample 0 of an unnormalized (B, T, C, H, W) tensor to a
    (T, H, W, C) numpy array clipped to [0, 1]."""
    frames = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    return np.clip(frames, 0, 1)


def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
    """Compute per-temporal-patch attention heatmaps for a clip.

    Args:
        video_frames: list of PIL frames.
        model: VideoMAE model (pretraining wrapper or bare encoder).
        processor: matching VideoMAEImageProcessor.
        layer_idx: encoder layer whose attention to use (negative = from the end).

    Returns:
        (original_frames, attention_maps): frames as (T, H, W, C) numpy in [0, 1],
        and one (H, W) heatmap normalized to [0, 1] per temporal patch.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    batch_size, time, num_channels, height, width = pixel_values.shape
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // model.config.tubelet_size

    encoder_model = _get_encoder(model)

    # SDPA kernels do not return attention weights; temporarily force eager attention.
    original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
    encoder_model.config._attn_implementation = "eager"
    try:
        with torch.no_grad():
            outputs = encoder_model(pixel_values, output_attentions=True)
    finally:
        if original_attn_impl is not None:
            encoder_model.config._attn_implementation = original_attn_impl

    attentions = outputs.attentions
    if layer_idx < 0:
        layer_idx = len(attentions) + layer_idx
    # Sample 0, averaged over heads: (seq, seq).
    avg_attn = attentions[layer_idx][0].mean(dim=0)

    frames_unnorm = _to_numpy_frames(_unnormalize(pixel_values))

    seq_len = avg_attn.shape[0]
    H_p = height // patch_size
    W_p = width // patch_size
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Tolerate a different temporal extent as long as the spatial grid fits.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")

    # Average attention each token *receives*, arranged on the patch grid.
    avg_attn_received = avg_attn.mean(dim=0)
    attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)

    attention_maps = []
    for t_idx in range(num_temporal_patches):
        attn_map = attn_per_patch[t_idx].detach().cpu().numpy()
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
        attention_maps.append(cv2.resize(attn_map, (width, height)))
    return frames_unnorm, attention_maps


def visualize_attention(video_frames, model, processor, layer_idx=-1):
    """Visualize the attention map of the first frame. Returns a PIL Image."""
    frames_unnorm, attention_maps = compute_attention_all_frames(
        video_frames, model, processor, layer_idx)
    tubelet_size = model.config.tubelet_size
    frame_idx = 0
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(frames_unnorm[frame_idx * tubelet_size])
    ax.imshow(attention_maps[frame_idx], alpha=0.5, cmap='jet')
    ax.set_title(f"Attention Map - Frame {frame_idx * tubelet_size}")
    ax.axis('off')
    return fig_to_image(fig)


def _compute_pca_latents(video_frames, model, processor):
    """Run the encoder and project patch embeddings onto their top-3 PCA axes.

    Returns:
        (frames_unnorm, pca_images, explained_variance): frames as (T, H, W, C)
        numpy in [0, 1]; one nearest-neighbor-upscaled RGB image per temporal
        patch (channels = PC1..PC3, each normalized per frame); and the summed
        explained-variance ratio of the three components.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    encoder_model = _get_encoder(model)
    with torch.no_grad():
        outputs = encoder_model(pixel_values, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state[0]

    batch_size, time, num_channels, height, width = pixel_values.shape
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // model.config.tubelet_size

    frames_unnorm = _to_numpy_frames(_unnormalize(pixel_values))

    seq_len = hidden_states.shape[0]
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")

    hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
    hidden_size = hidden_states_reshaped.shape[-1]
    hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()

    pca = PCA(n_components=3)
    pca_components = pca.fit_transform(hidden_states_flat)
    pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)

    # Arrange patches on a 2D grid: prefer square, else most-square factorization.
    H_p = int(np.sqrt(num_patches_per_frame))
    W_p = H_p
    if H_p * W_p == num_patches_per_frame:
        pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
    else:
        factors = [(i, num_patches_per_frame // i)
                   for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1)
                   if num_patches_per_frame % i == 0]
        if factors:
            H_p, W_p = factors[-1]
            pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
        else:
            raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")

    # Normalize each component per frame to [0, 1] so it maps to an RGB channel.
    for t in range(num_temporal_patches):
        for c in range(3):
            comp = pca_spatial[t, :, :, c]
            comp_min, comp_max = comp.min(), comp.max()
            if comp_max > comp_min:
                pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
            else:
                pca_spatial[t, :, :, c] = 0.5

    upscale_factor = 8
    pca_images = [
        cv2.resize(pca_spatial[t_idx],
                   (W_p * upscale_factor, H_p * upscale_factor),
                   interpolation=cv2.INTER_NEAREST)
        for t_idx in range(num_temporal_patches)
    ]
    return frames_unnorm, pca_images, pca.explained_variance_ratio_.sum()


def compute_latent_all_frames(video_frames, model, processor):
    """Compute PCA latent visualizations for all frames.

    Returns:
        (original_frames, pca_images) as numpy arrays.
    """
    frames_unnorm, pca_images, _ = _compute_pca_latents(video_frames, model, processor)
    return frames_unnorm, pca_images


def visualize_latent(video_frames, model, processor):
    """Visualize the PCA latent image of the first frame. Returns a PIL Image."""
    _, pca_images, explained = _compute_pca_latents(video_frames, model, processor)
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(pca_images[0])
    ax.set_title("PCA Components (RGB = PC1, PC2, PC3)")
    ax.axis('off')
    plt.suptitle(f"Explained Variance: {explained:.2%}", fontsize=12)
    plt.tight_layout()
    return fig_to_image(fig)


def compute_reconstruction_all_frames(video_frames, model, processor):
    """Run masked reconstruction (90% random tubelet masking) on a clip.

    Visible patches keep the input pixels; masked patches are replaced with the
    model's (de-normalized) predictions.

    Returns:
        (original_frames, reconstructed_frames) as (T, H, W, C) numpy arrays
        clipped to [0, 1].
    """
    inputs = processor(video_frames, return_tensors="pt")
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size

    # Build a random boolean mask over all tubelet patches (VideoMAE ratio 0.9).
    num_patches = (model.config.image_size // model.config.patch_size) ** 2
    num_temporal = model.config.num_frames // model.config.tubelet_size
    num_masked = int(0.9 * num_patches * num_temporal)
    total_patches = num_temporal * num_patches
    batch_size = inputs['pixel_values'].shape[0]
    bool_masked_pos = torch.zeros((batch_size, total_patches), dtype=torch.bool)
    for b in range(batch_size):
        mask_indices = np.random.choice(total_patches, num_masked, replace=False)
        bool_masked_pos[b, mask_indices] = True
    inputs['bool_masked_pos'] = bool_masked_pos.to(device)
    inputs['pixel_values'] = inputs['pixel_values'].to(device)

    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    pixel_values = inputs['pixel_values']
    batch_size, time, num_channels, height, width = pixel_values.shape
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    total_patches = num_temporal_patches * num_patches_per_frame

    frames_unnorm = _unnormalize(pixel_values)

    # Patchify: (B, T/ts, H/p, W/p, ts, p, p, C) flattened to (B, N, ts*p*p*C).
    frames_patched = frames_unnorm.view(
        batch_size,
        time // tubelet_size, tubelet_size,
        num_channels,
        height // patch_size, patch_size,
        width // patch_size, patch_size,
    )
    frames_patched = frames_patched.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
    videos_patch = frames_patched.view(
        batch_size,
        total_patches,
        tubelet_size * patch_size * patch_size * num_channels,
    )

    if model.config.norm_pix_loss:
        # NOTE(review): HF VideoMAE normalizes pixel targets *per patch*; the
        # statistics below are taken across patches (dim=-2 of (B, N, D)), so
        # this de-normalization is an approximation — confirm against the
        # upstream VideoMAE demo before relying on exact pixel values.
        patch_mean = videos_patch.mean(dim=-2, keepdim=True)
        patch_std = (videos_patch.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        logits_denorm = logits * patch_std + patch_mean
    else:
        logits_denorm = torch.clamp(logits, 0.0, 1.0)

    # Paste the predicted patches over the masked positions.
    reconstructed_patches = videos_patch.clone()
    reconstructed_patches[bool_masked_pos] = logits_denorm.reshape(
        -1, tubelet_size * patch_size * patch_size * num_channels)

    # Un-patchify back to (B, T, C, H, W).
    reconstructed = reconstructed_patches.view(
        batch_size,
        time // tubelet_size,
        height // patch_size,
        width // patch_size,
        tubelet_size, patch_size, patch_size,
        num_channels,
    )
    reconstructed = reconstructed.permute(0, 1, 4, 7, 2, 5, 3, 6).contiguous()
    reconstructed_frames = reconstructed.view(batch_size, time, num_channels, height, width)

    original_frames = _to_numpy_frames(frames_unnorm)
    reconstructed_frames_np = np.clip(
        reconstructed_frames[0].permute(0, 2, 3, 1).detach().cpu().numpy(), 0, 1)
    return original_frames, reconstructed_frames_np


def visualize_reconstruction(video_frames, model, processor):
    """Visualize the reconstruction of the first frame. Returns a PIL Image."""
    _, reconstructed_frames_np = compute_reconstruction_all_frames(video_frames, model, processor)
    tubelet_size = model.config.tubelet_size
    frame_idx = 0
    fig = plt.figure(figsize=(6, 6))
    ax = plt.subplot(111)
    ax.imshow(reconstructed_frames_np[frame_idx * tubelet_size])
    ax.set_title(f"Reconstructed Frame: {frame_idx * tubelet_size}")
    ax.axis('off')
    return fig_to_image(fig)


# Dummy function for backward compatibility
def process_video(video_path):
    """Read all frames of a video and pair each with a colormapped copy."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    visualizations = [
        cv2.applyColorMap((f * 0.5).astype(np.uint8), cv2.COLORMAP_JET)
        for f in frames
    ]
    return frames, visualizations


# Global state to store frames after upload
stored_frames = []
stored_viz = []
# Cache for visualization results: {video_path: {mode: {frame_idx: image}}}
visualization_cache = {}
current_video_path = None

# Directory holding bundled example videos. Computed once and used both for
# listing the example buttons and for loading the files, so they always agree
# (previously the buttons listed app/examples but the loader read examples/).
EXAMPLES_DIR = "app/examples" if os.path.exists("app/examples") else "examples"


def on_upload(video_path, mode):
    """Handle a newly selected video (or an uncached mode).

    Loads frames if the video changed, computes and caches per-frame
    visualizations for ``mode``, and returns the initial UI state.

    Returns:
        (slider update, first original frame, first visualization image)
    """
    global stored_frames, stored_viz, model, processor, visualization_cache, current_video_path
    if video_path is None:
        return gr.update(maximum=0), None, None

    if model is None:
        model, processor = initialize_model()

    video_path_str = str(video_path)
    need_to_compute_mode = (video_path_str not in visualization_cache
                            or mode not in visualization_cache[video_path_str])

    if video_path_str != current_video_path:
        print(f"Loading video: {video_path_str}")
        video_frames = load_video(video_path)
        stored_frames = video_frames
        current_video_path = video_path_str
    else:
        video_frames = stored_frames  # reuse already loaded frames

    visualization_cache.setdefault(video_path_str, {})

    if need_to_compute_mode:
        print(f"Computing {mode} visualization for all frames...")
        num_frames = len(stored_frames)
        tubelet_size = model.config.tubelet_size
        frame_cache = {}
        if mode == "reconstruction":
            _, reconstructed_frames = compute_reconstruction_all_frames(
                video_frames, model, processor)
            for i in range(num_frames):
                # Map stored frame index to model frame index.
                model_frame_idx = min(i, len(reconstructed_frames) - 1)
                fig = plt.figure(figsize=(6, 6))
                ax = plt.subplot(111)
                ax.imshow(reconstructed_frames[model_frame_idx])
                ax.set_title(f"Reconstructed Frame: {i}")
                ax.axis('off')
                frame_cache[i] = fig_to_image(fig)
        elif mode == "attention":
            original_frames, attention_maps = compute_attention_all_frames(
                video_frames, model, processor)
            for i in range(num_frames):
                # Map stored frame to temporal patch (one map per tubelet).
                temporal_patch_idx = min(i // tubelet_size, len(attention_maps) - 1)
                model_frame_idx = min(i, len(original_frames) - 1)
                if temporal_patch_idx < len(attention_maps):
                    fig = plt.figure(figsize=(6, 6))
                    ax = plt.subplot(111)
                    ax.imshow(original_frames[model_frame_idx])
                    ax.imshow(attention_maps[temporal_patch_idx], alpha=0.5, cmap='jet')
                    ax.set_title(f"Attention Map - Frame {i}")
                    ax.axis('off')
                    frame_cache[i] = fig_to_image(fig)
        elif mode == "latent":
            _, pca_images = compute_latent_all_frames(video_frames, model, processor)
            for i in range(num_frames):
                temporal_patch_idx = min(i // tubelet_size, len(pca_images) - 1)
                if temporal_patch_idx < len(pca_images):
                    fig = plt.figure(figsize=(6, 6))
                    ax = plt.subplot(111)
                    ax.imshow(pca_images[temporal_patch_idx])
                    ax.set_title(f"PCA Components - Frame {i}")
                    ax.axis('off')
                    frame_cache[i] = fig_to_image(fig)
        visualization_cache[video_path_str][mode] = frame_cache
        print(f"Caching complete for {mode} mode")

    max_idx = len(stored_frames) - 1
    if isinstance(stored_frames[0], Image.Image):
        first_frame = np.array(stored_frames[0])
    else:
        first_frame = stored_frames[0]
    # Fall back to the raw frame if the cache is somehow missing this entry.
    viz_img = visualization_cache.get(video_path_str, {}).get(mode, {}).get(0)
    if viz_img is None:
        viz_img = Image.fromarray(first_frame)
    return gr.update(maximum=max_idx, value=0), first_frame, viz_img


def update_frame(idx, mode):
    """Slider callback: return the original frame and its cached visualization."""
    global stored_frames, visualization_cache, current_video_path
    if not stored_frames:
        return None, None
    frame_idx = min(int(idx), len(stored_frames) - 1)
    if isinstance(stored_frames[frame_idx], Image.Image):
        frame = np.array(stored_frames[frame_idx])
    else:
        frame = stored_frames[frame_idx]
    # Cached lookup only (fast); fall back to the raw frame when uncached.
    viz_img = None
    if current_video_path:
        viz_img = (visualization_cache.get(current_video_path, {})
                   .get(mode, {}).get(frame_idx))
    if viz_img is None:
        viz_img = Image.fromarray(frame)
    return frame, viz_img


def load_example_video(video_file):
    """Return a click handler that loads one bundled example video.

    The closure binds ``video_file`` so each example button loads its own file
    from EXAMPLES_DIR (the same directory the buttons were listed from).
    """
    def _load_example_video(mode):
        example_path = os.path.join(EXAMPLES_DIR, video_file)
        return on_upload(example_path, mode)
    return _load_example_video


# --- Gradio UI Layout ---
with gr.Blocks(title="VideoMAE Representation Explorer") as demo:
    gr.Markdown("## 🎥 VideoMAE Frame-by-Frame Representation Explorer")
    mode_radio = gr.Radio(
        choices=["reconstruction", "attention", "latent"],
        value="reconstruction",
        label="Visualization Mode",
        info="Choose the type of visualization",
    )
    with gr.Row():
        with gr.Column():
            orig_output = gr.Image(label="Original Frame")
        with gr.Column():
            viz_output = gr.Image(label="Representation / Attention")
    frame_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Frame Index")

    # Event listeners: one button per bundled example video.
    video_lists = os.listdir(EXAMPLES_DIR) if os.path.exists(EXAMPLES_DIR) else []
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
        with gr.Column():
            for video_file in video_lists:
                load_example_btn = gr.Button(
                    f"Load Example Video ({video_file})", variant="secondary")
                load_example_btn.click(
                    load_example_video(video_file),
                    inputs=mode_radio,
                    outputs=[frame_slider, orig_output, viz_output])

    video_input.change(on_upload, inputs=[video_input, mode_radio],
                       outputs=[frame_slider, orig_output, viz_output])
    frame_slider.change(update_frame, inputs=[frame_slider, mode_radio],
                        outputs=[orig_output, viz_output])

    def on_mode_change(video_path, mode):
        """Mode radio callback: serve from cache when possible, else recompute."""
        if video_path is None:
            return gr.update(maximum=0), None, None
        video_path_str = str(video_path)
        cached = visualization_cache.get(video_path_str, {})
        if video_path_str == current_video_path and mode in cached:
            max_idx = len(stored_frames) - 1
            if isinstance(stored_frames[0], Image.Image):
                first_frame = np.array(stored_frames[0])
            else:
                first_frame = stored_frames[0]
            viz_img = cached[mode].get(0)
            if viz_img is None:
                viz_img = Image.fromarray(first_frame)
            return gr.update(maximum=max_idx, value=0), first_frame, viz_img
        # Not cached: recompute (reuses loaded frames when only the mode changed).
        return on_upload(video_path, mode)

    mode_radio.change(on_mode_change, inputs=[video_input, mode_radio],
                      outputs=[frame_slider, orig_output, viz_output])


if __name__ == "__main__":
    # Warm the model before serving requests.
    initialize_model()
    demo.launch()