# erow's picture
# revert
# 9adc4ed
import os
# Import spaces before torch to avoid CUDA initialization error
try:
import spaces
except ImportError:
pass
import gradio as gr
import cv2
import numpy as np
import torch
from PIL import Image
from pathlib import Path
from sklearn.decomposition import PCA
from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining, VideoMAEModel
from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
import io
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Helper function to convert matplotlib figure to PIL Image
def fig_to_image(fig):
    """Render a matplotlib figure into an in-memory PNG and return it as a PIL Image.

    The figure is closed after rendering so repeated calls do not leak figures.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png', dpi=100, bbox_inches='tight')
    buffer.seek(0)
    image = Image.open(buffer)
    plt.close(fig)
    return image
def load_video(video_path, num_frames=16, sample_rate=4):
    """
    Load `num_frames` frames from a video file, sampled every `sample_rate` frames.

    When the clip is too short for the requested stride, falls back to
    `num_frames` uniformly spaced samples over the whole clip.

    Returns:
        list of RGB PIL Images.

    Raises:
        FileNotFoundError: if `video_path` does not exist.
        ValueError: if the file cannot be opened, reports no frames, or no
            frame can be decoded.
    """
    video_path = Path(video_path)
    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0:
            # Guard: linspace(0, -1, ...) below would produce negative seek targets.
            raise ValueError(f"Video file reports no frames: {video_path}")
        if sample_rate * num_frames > frame_count:
            # Clip shorter than the requested stride: spread samples uniformly.
            frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
            print(f"warning: only {num_frames} frames are sampled from {frame_count} frames")
        else:
            frame_indices = np.arange(0, sample_rate * num_frames, sample_rate)
        print(f"Sampling {frame_indices}")
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes BGR; convert for PIL/model consumption.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
        print(f"Loaded {len(frames)} frames")
    finally:
        # Always release the capture handle, even on decode errors.
        cap.release()
    if not frames:
        raise ValueError(f"Could not decode any frames from: {video_path}")
    return frames
def load_model(model_name, model_type='pretraining'):
    """
    Instantiate a VideoMAE processor/model pair and move the model to `device`.

    model_type: 'base' loads the bare VideoMAEModel encoder; any other value
    loads VideoMAEForPreTraining (encoder plus reconstruction head).
    """
    processor = VideoMAEImageProcessor.from_pretrained(model_name)
    if model_type == 'base':
        net = VideoMAEModel.from_pretrained(model_name)
    else:
        net = VideoMAEForPreTraining.from_pretrained(model_name)
    return net.to(device), processor
# Global model and processor, lazily initialized by initialize_model()
model = None  # VideoMAE model instance (None until first use)
processor = None  # matching VideoMAEImageProcessor (None until first use)
def initialize_model(model_name='MCG-NJU/videomae-base'):
    """Load the global model/processor pair once; subsequent calls are no-ops."""
    global model, processor
    if model is not None:
        return model, processor
    print(f"Loading model: {model_name}")
    model, processor = load_model(model_name)
    print("Model loaded successfully")
    return model, processor
def visualize_attention(video_frames, model, processor, layer_idx=-1):
    """
    Visualize attention maps from VideoMAE model.

    Averages the selected layer's attention over heads and query positions,
    reshapes it onto the patch grid, and overlays the first temporal patch's
    map on the first video frame.

    Returns PIL Image for Gradio.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    # Use VideoMAEModel to get attention weights; the pretraining wrapper
    # keeps the encoder under `.videomae`.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model
    # Disable SDPA and use eager attention so attention weights are returned.
    original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
    encoder_model.config._attn_implementation = "eager"
    try:
        outputs = encoder_model(pixel_values, output_attentions=True)
    finally:
        # Restore the previous implementation even if the forward pass fails.
        if original_attn_impl is not None:
            encoder_model.config._attn_implementation = original_attn_impl
    attentions = outputs.attentions
    if layer_idx < 0:
        layer_idx = len(attentions) + layer_idx  # normalize negative index
    attention_weights = attentions[layer_idx][0]  # first batch element
    avg_attn = attention_weights.mean(dim=0)  # average over attention heads
    # Unnormalize frames back to [0, 1] RGB for display.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)
    seq_len = avg_attn.shape[0]
    H_p = height // patch_size
    W_p = width // patch_size
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Sequence length may differ from the naive expectation; re-derive the
        # temporal dimension when the spatial grid still divides evenly.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")
    avg_attn_received = avg_attn.mean(dim=0)  # attention received per key token
    attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
    # Create visualization for first frame
    frame_idx = 0
    frame_img = frames_unnorm[frame_idx * tubelet_size]
    attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
    # Min-max normalize for a consistent colormap range.
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    attn_map_upsampled = cv2.resize(attn_map, (width, height))
    # Create overlay
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(frame_img)
    ax.imshow(attn_map_upsampled, alpha=0.5, cmap='jet')
    ax.set_title(f"Attention Map - Frame {frame_idx * tubelet_size}")
    ax.axis('off')
    return fig_to_image(fig)
def visualize_latent(video_frames, model, processor):
    """
    Visualize latent space representations from VideoMAE model.

    Projects the encoder's last hidden states onto 3 PCA components and
    renders the first temporal patch as an RGB image (R/G/B = PC1/PC2/PC3).

    Returns PIL Image for Gradio.
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    # The pretraining wrapper keeps the encoder under `.videomae`.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model
    outputs = encoder_model(pixel_values, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state[0]  # (seq_len, hidden_size)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    # Unnormalize frames back to [0, 1] RGB for display.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)
    seq_len = hidden_states.shape[0]
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Re-derive the temporal dimension when the spatial grid still divides evenly.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")
    hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
    hidden_size = hidden_states_reshaped.shape[-1]
    hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()
    # PCA over all patch tokens; the 3 components map to RGB channels.
    pca = PCA(n_components=3)
    pca_components = pca.fit_transform(hidden_states_flat)
    pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)
    H_p = int(np.sqrt(num_patches_per_frame))
    W_p = H_p
    if H_p * W_p == num_patches_per_frame:
        pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
    else:
        # Non-square patch count: fall back to the most square factor pair.
        factors = []
        for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
            if num_patches_per_frame % i == 0:
                factors.append((i, num_patches_per_frame // i))
        if factors:
            H_p, W_p = factors[-1]
            pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
        else:
            raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
    # Normalize components per frame and per channel into [0, 1] for display.
    for t in range(num_temporal_patches):
        for c in range(3):
            comp = pca_spatial[t, :, :, c]
            comp_min, comp_max = comp.min(), comp.max()
            if comp_max > comp_min:
                pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
            else:
                pca_spatial[t, :, :, c] = 0.5  # constant channel -> mid gray
    # Show first frame
    frame_idx = 0
    frame_img = frames_unnorm[frame_idx * tubelet_size]  # NOTE(review): computed but not plotted
    rgb_image = pca_spatial[frame_idx]
    # Nearest-neighbor upscale keeps the patch grid visibly blocky.
    upscale_factor = 8
    rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
    fig = plt.figure(figsize=(6,6))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(rgb_image_upscaled)
    ax.set_title(f"PCA Components (RGB = PC1, PC2, PC3)")
    ax.axis('off')
    plt.suptitle(f"Explained Variance: {pca.explained_variance_ratio_.sum():.2%}", fontsize=12)
    plt.tight_layout()
    return fig_to_image(fig)
def compute_reconstruction_all_frames(video_frames, model, processor, mask_ratio=0.9):
    """
    Run masked reconstruction and return per-frame images for the whole clip.

    A random `mask_ratio` fraction of tubelet patches is masked; the model
    predicts the masked patches, and predictions are stitched back into the
    unmasked original patches to form the reconstructed video.

    Args:
        video_frames: frames accepted by the processor (e.g. list of PIL Images).
        model: VideoMAEForPreTraining model already on `device`.
        processor: matching VideoMAEImageProcessor.
        mask_ratio: fraction of patches to mask (default 0.9, as before).

    Returns:
        (original_frames, reconstructed_frames) as float numpy arrays of shape
        (time, height, width, channels), values clipped to [0, 1].
    """
    inputs = processor(video_frames, return_tensors="pt")
    # Patch counts derived from the model config (input is resized to match).
    num_patches = (model.config.image_size // model.config.patch_size) ** 2
    num_temporal = model.config.num_frames // model.config.tubelet_size
    total_patches = num_temporal * num_patches
    num_masked = int(mask_ratio * num_patches * num_temporal)
    batch_size = inputs['pixel_values'].shape[0]
    # Independent random mask per batch element.
    bool_masked_pos = torch.zeros((batch_size, total_patches), dtype=torch.bool)
    for b in range(batch_size):
        mask_indices = np.random.choice(total_patches, num_masked, replace=False)
        bool_masked_pos[b, mask_indices] = True
    inputs['bool_masked_pos'] = bool_masked_pos.to(device)
    inputs['pixel_values'] = inputs['pixel_values'].to(device)
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    pixel_values = inputs['pixel_values']
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    total_patches = (time // tubelet_size) * num_patches_per_frame
    # Undo ImageNet normalization so patches live in [0, 1] pixel space.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    # Patchify: (B, T/ts, ts, C, H/p, p, W/p, p) -> (B, patches, ts*p*p*C).
    frames_patched = frames_unnorm.view(
        batch_size, time // tubelet_size, tubelet_size, num_channels,
        height // patch_size, patch_size, width // patch_size, patch_size,
    )
    frames_patched = frames_patched.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
    videos_patch = frames_patched.view(
        batch_size, total_patches, tubelet_size * patch_size * patch_size * num_channels,
    )
    if model.config.norm_pix_loss:
        # Training targets were per-patch normalized; invert that here.
        patch_mean = videos_patch.mean(dim=-2, keepdim=True)
        patch_std = (videos_patch.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        logits_denorm = logits * patch_std + patch_mean
    else:
        logits_denorm = torch.clamp(logits, 0.0, 1.0)
    # Keep visible patches from the input; fill masked ones with predictions.
    reconstructed_patches = videos_patch.clone()
    reconstructed_patches[bool_masked_pos] = logits_denorm.reshape(-1, tubelet_size * patch_size * patch_size * num_channels)
    # Un-patchify back to (B, T, C, H, W).
    reconstructed_patches_reshaped = reconstructed_patches.view(
        batch_size, time // tubelet_size, height // patch_size, width // patch_size,
        tubelet_size, patch_size, patch_size, num_channels,
    )
    reconstructed_patches_reshaped = reconstructed_patches_reshaped.permute(0, 1, 4, 7, 2, 5, 3, 6).contiguous()
    reconstructed_frames = reconstructed_patches_reshaped.view(
        batch_size, time, num_channels, height, width,
    )
    original_frames = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    reconstructed_frames_np = reconstructed_frames[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    original_frames = np.clip(original_frames, 0, 1)
    reconstructed_frames_np = np.clip(reconstructed_frames_np, 0, 1)
    return original_frames, reconstructed_frames_np
def visualize_reconstruction(video_frames, model, processor):
    """
    Visualize reconstruction from VideoMAE model.

    Delegates the masking / forward pass / un-patchify work to
    `compute_reconstruction_all_frames` (previously duplicated here verbatim)
    and plots the first reconstructed frame.

    Returns PIL Image for Gradio.
    """
    _, reconstructed_frames_np = compute_reconstruction_all_frames(video_frames, model, processor)
    tubelet_size = model.config.tubelet_size
    # Show first frame
    frame_idx = 0
    fig = plt.figure(figsize=(6, 6))
    ax = plt.subplot(111)
    ax.imshow(reconstructed_frames_np[frame_idx * tubelet_size])
    ax.set_title(f"Reconstructed Frame: {frame_idx * tubelet_size}")
    ax.axis('off')
    return fig_to_image(fig)
def compute_attention_all_frames(video_frames, model, processor, layer_idx=-1):
    """
    Compute attention maps for all frames.

    Averages the selected layer's attention over heads and query positions,
    then upsamples each temporal patch's map to the input resolution.

    Returns: (original_frames, attention_maps) as numpy arrays
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    # The pretraining wrapper keeps the encoder under `.videomae`.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model
    # Force eager attention so attention weights are actually returned.
    original_attn_impl = getattr(encoder_model.config, '_attn_implementation', None)
    encoder_model.config._attn_implementation = "eager"
    try:
        outputs = encoder_model(pixel_values, output_attentions=True)
    finally:
        # Restore the previous implementation even if the forward pass fails.
        if original_attn_impl is not None:
            encoder_model.config._attn_implementation = original_attn_impl
    attentions = outputs.attentions
    if layer_idx < 0:
        layer_idx = len(attentions) + layer_idx  # normalize negative index
    attention_weights = attentions[layer_idx][0]  # first batch element
    avg_attn = attention_weights.mean(dim=0)  # average over attention heads
    # Unnormalize frames back to [0, 1] RGB for display.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)
    seq_len = avg_attn.shape[0]
    H_p = height // patch_size
    W_p = width // patch_size
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Re-derive the temporal dimension when the spatial grid still divides evenly.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape attention: seq_len={seq_len}, expected={expected_seq_len}")
    avg_attn_received = avg_attn.mean(dim=0)  # attention received per key token
    attn_per_patch = avg_attn_received.reshape(num_temporal_patches, H_p, W_p)
    # Create attention maps for all temporal patches
    attention_maps = []
    for frame_idx in range(num_temporal_patches):
        attn_map = attn_per_patch[frame_idx].detach().cpu().numpy()
        # Min-max normalize each map for a consistent colormap range.
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
        attn_map_upsampled = cv2.resize(attn_map, (width, height))
        attention_maps.append(attn_map_upsampled)
    return frames_unnorm, attention_maps
def compute_latent_all_frames(video_frames, model, processor):
    """
    Compute PCA latent visualizations for all frames.

    Projects the encoder's last hidden states onto 3 PCA components and builds
    one upscaled RGB image per temporal patch (R/G/B = PC1/PC2/PC3).

    Returns: (original_frames, pca_images) as numpy arrays
    """
    inputs = processor(video_frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    # The pretraining wrapper keeps the encoder under `.videomae`.
    if hasattr(model, 'videomae'):
        encoder_model = model.videomae
    else:
        encoder_model = model
    outputs = encoder_model(pixel_values, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state[0]  # (seq_len, hidden_size)
    batch_size, time, num_channels, height, width = pixel_values.shape
    tubelet_size = model.config.tubelet_size
    patch_size = model.config.patch_size
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    num_temporal_patches = time // tubelet_size
    # Unnormalize frames back to [0, 1] RGB for display.
    dtype = pixel_values.dtype
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
    frames_unnorm = pixel_values * std + mean
    frames_unnorm = frames_unnorm[0].permute(0, 2, 3, 1).detach().cpu().numpy()
    frames_unnorm = np.clip(frames_unnorm, 0, 1)
    seq_len = hidden_states.shape[0]
    expected_seq_len = num_temporal_patches * num_patches_per_frame
    if seq_len != expected_seq_len:
        # Re-derive the temporal dimension when the spatial grid still divides evenly.
        if seq_len % num_patches_per_frame == 0:
            num_temporal_patches = seq_len // num_patches_per_frame
        else:
            raise ValueError(f"Cannot reshape hidden states: seq_len={seq_len}, expected={expected_seq_len}")
    hidden_states_reshaped = hidden_states.reshape(num_temporal_patches, num_patches_per_frame, -1)
    hidden_size = hidden_states_reshaped.shape[-1]
    hidden_states_flat = hidden_states_reshaped.reshape(-1, hidden_size).detach().cpu().numpy()
    # PCA over all patch tokens; the 3 components map to RGB channels.
    pca = PCA(n_components=3)
    pca_components = pca.fit_transform(hidden_states_flat)
    pca_reshaped = pca_components.reshape(num_temporal_patches, num_patches_per_frame, 3)
    H_p = int(np.sqrt(num_patches_per_frame))
    W_p = H_p
    if H_p * W_p == num_patches_per_frame:
        pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
    else:
        # Non-square patch count: fall back to the most square factor pair.
        factors = []
        for i in range(1, int(np.sqrt(num_patches_per_frame)) + 1):
            if num_patches_per_frame % i == 0:
                factors.append((i, num_patches_per_frame // i))
        if factors:
            H_p, W_p = factors[-1]
            pca_spatial = pca_reshaped.reshape(num_temporal_patches, H_p, W_p, 3)
        else:
            raise ValueError(f"Cannot reshape {num_patches_per_frame} patches into a 2D grid")
    # Normalize components per frame and per channel into [0, 1] for display.
    for t in range(num_temporal_patches):
        for c in range(3):
            comp = pca_spatial[t, :, :, c]
            comp_min, comp_max = comp.min(), comp.max()
            if comp_max > comp_min:
                pca_spatial[t, :, :, c] = (comp - comp_min) / (comp_max - comp_min)
            else:
                pca_spatial[t, :, :, c] = 0.5  # constant channel -> mid gray
    # Create upscaled images for all frames
    # Nearest-neighbor upscale keeps the patch grid visibly blocky.
    upscale_factor = 8
    pca_images = []
    for t_idx in range(num_temporal_patches):
        rgb_image = pca_spatial[t_idx]
        rgb_image_upscaled = cv2.resize(rgb_image, (W_p * upscale_factor, H_p * upscale_factor), interpolation=cv2.INTER_NEAREST)
        pca_images.append(rgb_image_upscaled)
    return frames_unnorm, pca_images
# Dummy function for backward compatibility
def process_video(video_path):
    """Dummy helper kept for backward compatibility.

    Reads every frame of the video as RGB and pairs each with a
    half-brightness JET-colormap rendering.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ok, bgr = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    cap.release()
    visualizations = [
        cv2.applyColorMap((frame * 0.5).astype(np.uint8), cv2.COLORMAP_JET)
        for frame in frames
    ]
    return frames, visualizations
# Global state to store frames after upload
stored_frames = []  # PIL frames of the currently loaded video
stored_viz = []  # legacy visualization store (not read anywhere visible here)
# Cache for visualization results: {video_path: {mode: {frame_idx: image}}}
visualization_cache = {}
current_video_path = None  # path string of the video currently in stored_frames
def on_upload(video_path, mode):
    """Gradio callback: load a video, compute and cache per-frame visualizations.

    Lazily initializes the global model, reuses already-loaded frames when the
    same video is re-selected, and only recomputes a mode that is not cached.

    Returns (slider update, first original frame, first visualization image).
    """
    global stored_frames, stored_viz, model, processor, visualization_cache, current_video_path
    if video_path is None:
        return gr.update(maximum=0), None, None
    # Initialize model if needed
    if model is None:
        model, processor = initialize_model()
    # Check if we need to recompute (new video or mode not cached)
    video_path_str = str(video_path)
    need_to_load_video = (video_path_str != current_video_path)
    need_to_compute_mode = (video_path_str not in visualization_cache or mode not in visualization_cache[video_path_str])
    if need_to_load_video:
        # Load video frames
        print(f"Loading video: {video_path_str}")
        video_frames = load_video(video_path)
        stored_frames = video_frames
        current_video_path = video_path_str
    else:
        # Reuse already loaded frames
        video_frames = stored_frames
    # Initialize cache for this video
    if video_path_str not in visualization_cache:
        visualization_cache[video_path_str] = {}
    if need_to_compute_mode:
        # Compute all visualizations and cache them
        print(f"Computing {mode} visualization for all frames...")
        num_frames = len(stored_frames)
        tubelet_size = model.config.tubelet_size
        if mode == "reconstruction":
            original_frames, reconstructed_frames = compute_reconstruction_all_frames(video_frames, model, processor)
            # Cache as images per frame - map model frames to stored frames
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                # Map stored frame index to model frame index (clamped to range)
                model_frame_idx = min(i, len(reconstructed_frames) - 1)
                fig = plt.figure(figsize=(6, 6))
                ax = plt.subplot(111)
                ax.imshow(reconstructed_frames[model_frame_idx])
                ax.set_title(f"Reconstructed Frame: {i}")
                ax.axis('off')
                visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
        elif mode == "attention":
            original_frames, attention_maps = compute_attention_all_frames(video_frames, model, processor)
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                # Map stored frame to temporal patch (one map per tubelet)
                temporal_patch_idx = min(i // tubelet_size, len(attention_maps) - 1)
                model_frame_idx = min(i, len(original_frames) - 1)
                if temporal_patch_idx < len(attention_maps):
                    fig = plt.figure(figsize=(6, 6))
                    ax = plt.subplot(111)
                    ax.imshow(original_frames[model_frame_idx])
                    ax.imshow(attention_maps[temporal_patch_idx], alpha=0.5, cmap='jet')
                    ax.set_title(f"Attention Map - Frame {i}")
                    ax.axis('off')
                    visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
        elif mode == "latent":
            original_frames, pca_images = compute_latent_all_frames(video_frames, model, processor)
            visualization_cache[video_path_str][mode] = {}
            for i in range(num_frames):
                # Map stored frame to temporal patch (one PCA image per tubelet)
                temporal_patch_idx = min(i // tubelet_size, len(pca_images) - 1)
                if temporal_patch_idx < len(pca_images):
                    fig = plt.figure(figsize=(6, 6))
                    ax = plt.subplot(111)
                    ax.imshow(pca_images[temporal_patch_idx])
                    ax.set_title(f"PCA Components - Frame {i}")
                    ax.axis('off')
                    visualization_cache[video_path_str][mode][i] = fig_to_image(fig)
        print(f"Caching complete for {mode} mode")
    # Load from cache
    max_idx = len(stored_frames) - 1
    frame_idx = 0
    # Get original frame
    if isinstance(stored_frames[0], Image.Image):
        first_frame = np.array(stored_frames[0])
    else:
        first_frame = stored_frames[0]
    # Get visualization from cache
    if video_path_str in visualization_cache and mode in visualization_cache[video_path_str]:
        if frame_idx in visualization_cache[video_path_str][mode]:
            viz_img = visualization_cache[video_path_str][mode][frame_idx]
        else:
            # Fallback if frame not in cache
            viz_img = Image.fromarray(first_frame)
    else:
        # Fallback if not cached
        viz_img = Image.fromarray(first_frame)
    return gr.update(maximum=max_idx, value=0), first_frame, viz_img
def update_frame(idx, mode):
    """Slider callback: return (original frame, cached visualization) for index `idx`.

    Falls back to showing the raw frame whenever the requested visualization
    is not in the cache. Never recomputes anything.
    """
    global stored_frames, visualization_cache, current_video_path
    if not stored_frames:
        return None, None
    # Clamp to the last available frame.
    frame_idx = min(int(idx), len(stored_frames) - 1)
    frame = stored_frames[frame_idx]
    if isinstance(frame, Image.Image):
        frame = np.array(frame)
    # Walk the nested cache with .get() instead of nested if-blocks.
    mode_cache = visualization_cache.get(current_video_path, {}).get(mode, {}) if current_video_path else {}
    viz_img = mode_cache.get(frame_idx)
    if viz_img is None:
        # Not cached: show the plain frame.
        viz_img = Image.fromarray(frame)
    return frame, viz_img
def load_example_video(video_file):
    """Return a one-argument closure that loads examples/<video_file> via on_upload."""
    def _loader(mode):
        """Load the predefined example video for the given visualization mode."""
        return on_upload(f"examples/{video_file}", mode)
    return _loader
# --- Gradio UI Layout ---
with gr.Blocks(title="VideoMAE Representation Explorer") as demo:
    gr.Markdown("## 🎥 VideoMAE Frame-by-Frame Representation Explorer")
    mode_radio = gr.Radio(
        choices=["reconstruction", "attention", "latent"],
        value="reconstruction",
        label="Visualization Mode",
        info="Choose the type of visualization"
    )
    with gr.Row():
        with gr.Column():
            orig_output = gr.Image(label="Original Frame")
        with gr.Column():
            viz_output = gr.Image(label="Representation / Attention")
    frame_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Frame Index")
    # Event Listeners
    # Prefer app/examples when present; otherwise fall back to examples/.
    video_lists = os.listdir("examples") if os.path.exists("examples") else []
    video_lists = os.listdir("app/examples") if os.path.exists("app/examples") else video_lists
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
        with gr.Column():
            # One "load example" button per bundled clip.
            for video_file in video_lists:
                load_example_btn = gr.Button(f"Load Example Video ({video_file})", variant="secondary")
                load_example_btn.click(load_example_video(video_file), inputs=mode_radio, outputs=[frame_slider, orig_output, viz_output])
            # load_example_btn = gr.Button("Load Example Video (dog.mp4)", variant="secondary")
    video_input.change(on_upload, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
    frame_slider.change(update_frame, inputs=[frame_slider, mode_radio], outputs=[orig_output, viz_output])
    def on_mode_change(video_path, mode):
        """Handle mode change - compute if not cached, otherwise use cache"""
        global stored_frames, model, processor, visualization_cache, current_video_path
        if video_path is None:
            return gr.update(maximum=0), None, None
        video_path_str = str(video_path)
        # If video is already loaded and mode is cached, just return cached result
        if video_path_str == current_video_path and video_path_str in visualization_cache:
            if mode in visualization_cache[video_path_str]:
                max_idx = len(stored_frames) - 1
                frame_idx = 0
                if isinstance(stored_frames[0], Image.Image):
                    first_frame = np.array(stored_frames[0])
                else:
                    first_frame = stored_frames[0]
                if frame_idx in visualization_cache[video_path_str][mode]:
                    viz_img = visualization_cache[video_path_str][mode][frame_idx]
                else:
                    viz_img = Image.fromarray(first_frame)
                return gr.update(maximum=max_idx, value=0), first_frame, viz_img
        # Otherwise, compute (will use cached video frames if available)
        return on_upload(video_path, mode)
    mode_radio.change(on_mode_change, inputs=[video_input, mode_radio], outputs=[frame_slider, orig_output, viz_output])
if __name__ == "__main__":
    # Initialize model at startup so the first request does not pay the load cost.
    initialize_model()
    demo.launch()