Spaces:
Running on Zero
# IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
ZEROGPU_AVAILABLE = False
try:
    import spaces
except ImportError:
    # Not running on a ZeroGPU Space; continue without the spaces package.
    pass
else:
    ZEROGPU_AVAILABLE = True
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import av | |
| import imageio | |
| from transformers import VivitImageProcessor | |
| from PIL import Image, ImageDraw, ImageFont | |
| from omegaconf import OmegaConf | |
| from einops import rearrange | |
| from tqdm import trange | |
| from autogaze.models.autogaze import AutoGaze | |
| from autogaze.datasets.video_utils import read_video_pyav, transform_video_for_pytorch | |
| from autogaze.tasks.video_mae_reconstruction import VideoMAEReconstruction | |
| from autogaze.utils import UnNormalize | |
def image_to_video(image_path, output_path, fps):
    """
    Convert a single image to a single-frame video file.

    Args:
        image_path: Path to input image
        output_path: Path to output video file
        fps: Frame rate for the video

    Returns:
        Dictionary with video metadata (width, height, frames, fps)
    """
    # Context manager releases the PIL file handle (the previous version
    # left the image file open for the life of the process).
    with Image.open(image_path) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img_array = np.array(img)
    # macro_block_size=1 lets ffmpeg encode arbitrary (non-multiple-of-16)
    # resolutions without resizing.
    with imageio.get_writer(output_path, fps=fps, format='FFMPEG', codec='libx264',
                            pixelformat='yuv420p', macro_block_size=1) as writer:
        writer.append_data(img_array)
    return {
        'width': img_array.shape[1],
        'height': img_array.shape[0],
        'frames': 1,
        'fps': fps,
    }
def load_model(device='cuda'):
    """
    Load the AutoGaze model and its fine-tuned VideoMAE reconstruction task.

    Args:
        device: Device to place both models on (default 'cuda').

    Returns:
        Dictionary with:
            'model': AutoGaze model in eval mode on `device`
            'task': VideoMAEReconstruction task in eval mode on `device`
            'unnorm': UnNormalize helper matching the processor statistics
            'scales': the model's scale list
            'transform': VivitImageProcessor sized to the finest scale
    """
    print("Loading AutoGaze model from HuggingFace...")
    model = AutoGaze.from_pretrained("nvidia/AutoGaze")
    model = model.to(device)
    model.eval()
    # Processor is sized to the finest (last) scale the model operates at.
    transform = VivitImageProcessor.from_pretrained(
        "facebook/vit-mae-large",
        size=model.scales[-1],
        crop_size=model.scales[-1]
    )
    unnorm = UnNormalize(
        mean=transform.image_mean,
        std=transform.image_std,
        rescale_factor=transform.rescale_factor
    )
    print("Loading VideoMAE model from HuggingFace...")
    scales_str = '+'.join(map(str, model.scales))
    recon_model_config = OmegaConf.create({
        'scale_embed': True,
        'max_num_frames': 256,
        'time_embed': True,
        'causal': True,
        'loss_type': 'l1+dinov2_reg+siglip2',
        'loss_weights': '1',
        'l1_loss_config': {},
        'dinov2_reg_loss_config': {
            'model': 'facebook/dinov2-with-registers-base'
        },
        'siglip2_loss_config': {
            'model': 'google/siglip2-base-patch16-224'
        }
    })
    task = VideoMAEReconstruction(
        recon_model='facebook/vit-mae-large',
        recon_model_config=recon_model_config,
        scales=scales_str,
        recon_sample_rate=1,
        attn_mode='sdpa'
    )
    # Load fine-tuned weights from HuggingFace
    from huggingface_hub import hf_hub_download
    checkpoint_path = hf_hub_download(repo_id="bfshi/VideoMAE_AutoGaze", filename="videomae.pt")
    print(f"Loading VideoMAE checkpoint from {checkpoint_path}...")
    # weights_only=True: the checkpoint is a plain tensor state dict, so avoid
    # unpickling arbitrary Python objects from a downloaded file.
    task_sd = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    # Checkpoint keys are prefixed 'module.mae.' (DDP-wrapped training); strip
    # the prefix so they match task.mae's parameter names.
    task_sd = {k.replace('module.mae.', ''): v for k, v in task_sd.items()}
    task.mae.load_state_dict(task_sd, strict=True)
    print("Loaded VideoMAE checkpoint from HuggingFace")
    task = task.to(device)
    task.eval()
    return {
        'model': model,
        'task': task,
        'unnorm': unnorm,
        'scales': model.scales,
        'transform': transform,
    }
def process_video(video_path, setup, gazing_ratio=0.75, task_loss_requirement=0.6, progress_callback=None, spatial_batch_size=16):
    """
    Process a video file with AutoGaze using chunking for any resolution/duration.

    Args:
        video_path: Path to video file
        setup: Dictionary with model, task, unnorm, scales, transform
        gazing_ratio: Maximum percentage of patches to gaze per frame
        task_loss_requirement: Reconstruction loss threshold
        progress_callback: Optional callback function for progress updates
        spatial_batch_size: Number of 224x224 spatial chunk locations processed
            per AutoGaze mini-batch (halved automatically for videos > 200 frames)

    Yields:
        None while work is in progress (progress heartbeat for a UI), then a
        final dictionary with original frames, gazing frames, reconstruction
        frames, and statistics.
    """
    model = setup['model']
    task = setup['task']
    transform = setup['transform']
    device = next(model.parameters()).device
    # `device` is a torch.device (e.g. cuda:0); compare its `.type` — comparing
    # the device object to the string 'cuda' does not match indexed devices.
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    # First open: read stream metadata (frame count, fps) only.
    container = av.open(video_path)
    video_stream = container.streams.video[0]
    total_frames_available = video_stream.frames
    fps = float(video_stream.average_rate)
    container.close()
    # Second open: decode every frame.
    container = av.open(video_path)
    sample_indices = list(range(total_frames_available))
    video = read_video_pyav(container=container, indices=sample_indices)  # (T, H, W, 3) numpy array
    container.close()
    # Keep video on CPU for preprocessing to save GPU memory
    video_tensor = torch.from_numpy(video).float()  # (T, H, W, 3)
    video_tensor = video_tensor / 255.0  # Normalize to [0, 1]
    video_tensor = video_tensor.permute(0, 3, 1, 2)  # (T, C, H, W)
    T, C, H, W = video_tensor.shape
    if T > 200:
        # Long videos need smaller spatial mini-batches to fit in GPU memory.
        print(f'Video has {T} frames, which may require significant GPU memory. Decreasing spatial_batch_size from {spatial_batch_size} to {spatial_batch_size // 2}.')
        spatial_batch_size //= 2
    # Clone for later visualization (keep on CPU)
    video_tensor_original = video_tensor.clone()
    # Pad video to be divisible by 224x224 spatially and 16 frames temporally
    pad_t = (16 - T % 16) % 16
    pad_h = (224 - H % 224) % 224
    pad_w = (224 - W % 224) % 224
    if pad_t > 0 or pad_h > 0 or pad_w > 0:
        # F.pad pads trailing dims first: (W, W, H, H, C, C, T, T)
        video_tensor = F.pad(video_tensor, (0, pad_w, 0, pad_h, 0, 0, 0, pad_t))
    # Chunk video into 16-frame, 224x224 chunks (following QUICK_START.md)
    video_tensor = video_tensor.unsqueeze(0)  # 1 * T * C * H * W
    # Calculate chunking dimensions
    nt = (T + pad_t) // 16
    nh = (H + pad_h) // 224
    nw = (W + pad_w) // 224
    num_spatial_chunks = nh * nw
    num_chunks = nt * num_spatial_chunks
    # Chunk into (num_chunks, 16, C, 224, 224)
    video_chunks = rearrange(video_tensor, 'B (nt t) C (nh h) (nw w) -> (B nt nh nw) t C h w', t=16, h=224, w=224)
    print(f"Video chunked into {num_chunks} chunks ({nt} temporal x {num_spatial_chunks} spatial) of shape (16, {C}, 224, 224). Original shape: ({T}, {C}, {H}, {W})")
    # Apply VivitImageProcessor normalization to chunks
    # Rearrange chunks to process all frames: (num_chunks, 16, C, H, W) -> (num_chunks * 16, C, H, W)
    chunks_flat = rearrange(video_chunks, 'b t c h w -> (b t) c h w')
    # Apply normalization using VivitImageProcessor's mean and std (on CPU)
    mean = torch.tensor(transform.image_mean).view(1, 3, 1, 1)
    std = torch.tensor(transform.image_std).view(1, 3, 1, 1)
    chunks_flat = (chunks_flat - mean) / std
    video_chunks = rearrange(chunks_flat, '(b t) c h w -> b t c h w', b=num_chunks, t=16)
    # NOTE(review): the chunking above flattened the batch axis temporal-major
    # ((nt, nh, nw)) but this rearrange reinterprets it as spatial-major
    # ((ns, nt)); the two orderings coincide only when nt == 1 — confirm
    # correctness for videos longer than 16 frames.
    video_chunks = rearrange(video_chunks, '(ns nt) t c h w -> ns nt t c h w', ns=num_spatial_chunks, nt=nt)
    # Keep video_chunks on CPU - only move mini-batches to GPU as needed
    print(f'video_chunks shape (spatial, temporal, frames, C, H, W): {video_chunks.shape}')
    del video_tensor, chunks_flat, mean, std
    with torch.inference_mode():
        # Process spatial locations in mini-batches (keep all temporal chunks together per spatial location)
        num_spatial_batches = (num_spatial_chunks + spatial_batch_size - 1) // spatial_batch_size
        all_gaze_outputs = []
        total_gazing_tokens = 0
        for batch_idx in range(num_spatial_batches):
            start_idx = batch_idx * spatial_batch_size
            end_idx = min(start_idx + spatial_batch_size, num_spatial_chunks)
            batch_size = end_idx - start_idx
            gazing_pct = int(((batch_idx + 1) / num_spatial_batches) * 100)
            if progress_callback:
                progress_callback(0.1 + 0.4 * (batch_idx / num_spatial_batches), f"Gazing progress: {gazing_pct}%")
                yield None
            spatial_batch = video_chunks[start_idx:end_idx].to(device)
            spatial_batch = rearrange(spatial_batch, 'bs nt t c h w -> (bs nt) t c h w')
            print(f'Processing spatial batch {batch_idx+1}/{num_spatial_batches} with {batch_size} spatial locations x {nt} temporal = {spatial_batch.shape[0]} chunks')
            # Run AutoGaze on this mini-batch
            batch_gaze_output = model({"video": spatial_batch}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement)
            # Assumes the first T entries cover the un-padded frames — TODO confirm
            # against AutoGaze's output layout.
            num_gazing_each_frame = batch_gaze_output['num_gazing_each_frame'][:T]
            num_gazing_total = num_gazing_each_frame.sum().item()
            # Free GPU memory after forward pass
            del spatial_batch
            # Count gazing tokens for this batch
            if_padded = batch_gaze_output.get('if_padded_gazing')
            if if_padded is not None:
                print(f'shape of if_padded: {if_padded.shape}')
                if_padded = if_padded[:, :min(num_gazing_total, if_padded.shape[1])]
                new_gazing_tokens = (~if_padded).sum().item()
            else:
                new_gazing_tokens = (batch_gaze_output['gazing_pos'] < (196 * T)).sum().item()
            total_gazing_tokens += new_gazing_tokens
            print(f'Batch {batch_idx+1}: Gazing tokens = {new_gazing_tokens}, Total gazing tokens so far = {total_gazing_tokens}')
            # Store the output
            all_gaze_outputs.append(batch_gaze_output)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        print("Merging mini-batch results...")
        # Find max sequence length across all mini-batches
        max_seq_len = max(out['gazing_pos'].shape[1] for out in all_gaze_outputs)
        # Pad gazing_pos and if_padded_gazing to same length (they have variable seq length)
        # gazing_mask doesn't need padding since all chunks have same shape
        padded_gazing_pos = []
        padded_if_padded_gazing = []
        for out in all_gaze_outputs:
            seq_len = out['gazing_pos'].shape[1]
            pad_len = max_seq_len - seq_len
            # Pad gazing_pos with zeros
            padded_pos = F.pad(out['gazing_pos'], (0, pad_len), value=0)
            padded_gazing_pos.append(padded_pos)
            # Pad if_padded_gazing and mark new positions as True (padded)
            if 'if_padded_gazing' in out:
                padded_if_pad = F.pad(out['if_padded_gazing'], (0, pad_len), value=True)
                padded_if_padded_gazing.append(padded_if_pad)
        # Store num_gazing_each_frame per mini-batch for later per-chunk extraction
        num_gazing_each_frame_list = [out['num_gazing_each_frame'] for out in all_gaze_outputs]
        batch_sizes = [out['gazing_pos'].shape[0] for out in all_gaze_outputs]
        gaze_output = {
            'gazing_pos': torch.cat(padded_gazing_pos, dim=0),
            'gazing_mask': [torch.cat([out['gazing_mask'][i] for out in all_gaze_outputs], dim=0) for i in range(4)],
            'num_gazing_each_frame_list': num_gazing_each_frame_list,  # List of values per mini-batch
            'batch_sizes': batch_sizes,  # Track which chunks came from which mini-batch
            'frame_sampling_rate': all_gaze_outputs[0]['frame_sampling_rate'],
            'num_vision_tokens_each_frame': all_gaze_outputs[0]['num_vision_tokens_each_frame'],
        }
        if len(padded_if_padded_gazing) > 0:
            gaze_output['if_padded_gazing'] = torch.cat(padded_if_padded_gazing, dim=0)
        # Clean up mini-batch outputs
        del all_gaze_outputs
        # 196 patches per 224x224 frame at patch size 16; denominator for the
        # gazing-percentage statistic.
        total_possible_tokens = 196 * min(T, 16) * num_chunks
        # Extract gazing masks for later visualization (already in batched form)
        gazing_masks_batched = gaze_output['gazing_mask']  # List of 4 scales, each (num_chunks, 16, num_patches)
        # Flatten video_chunks back to (num_chunks, 16, C, H, W) for reconstruction
        video_chunks_flat = rearrange(video_chunks, 'ns nt t c h w -> (ns nt) t c h w').cpu()
        # Pre-allocate reconstruction tensor on CPU to avoid memory accumulation
        total_frames = num_chunks * 16
        C = video_chunks_flat.shape[2]
        reconstruction_chunks = torch.zeros((total_frames, C, 224, 224), dtype=torch.float32)
        frame_idx_counter = 0
        # Process reconstruction in mini-batches matching AutoGaze batch structure
        num_autogaze_batches = len(gaze_output['num_gazing_each_frame_list'])
        print(f'Reconstructing {num_chunks} chunks in {num_autogaze_batches} batches (aligned with AutoGaze batches)...')
        chunk_idx = 0
        for autogaze_batch_idx in range(num_autogaze_batches):
            batch_size = gaze_output['batch_sizes'][autogaze_batch_idx]
            start_chunk_idx = chunk_idx
            end_chunk_idx = chunk_idx + batch_size
            print(f'Reconstructing chunks {start_chunk_idx+1}-{end_chunk_idx}/{num_chunks}...')
            # Extract videos for all chunks in this AutoGaze batch
            batch_videos = video_chunks_flat[start_chunk_idx:end_chunk_idx].to(device)  # (batch_size, 16, C, H, W)
            # Extract gazing data for all chunks in this AutoGaze batch
            batch_gazing_pos = gaze_output['gazing_pos'][start_chunk_idx:end_chunk_idx]
            batch_gazing_mask = [scale_mask[start_chunk_idx:end_chunk_idx] for scale_mask in gaze_output['gazing_mask']]
            batch_num_gazing_each_frame = gaze_output['num_gazing_each_frame_list'][autogaze_batch_idx]
            # Trim to expected sequence length for this AutoGaze batch
            expected_seq_len = batch_num_gazing_each_frame.sum().item()
            batch_gazing_pos = batch_gazing_pos[:, :expected_seq_len]
            chunk_idx = end_chunk_idx
            batch_gaze_output = {
                'gazing_pos': batch_gazing_pos,
                'gazing_mask': batch_gazing_mask,
                'num_gazing_each_frame': batch_num_gazing_each_frame,
                'frame_sampling_rate': gaze_output['frame_sampling_rate'],
                'num_vision_tokens_each_frame': gaze_output['num_vision_tokens_each_frame'],
            }
            if 'if_padded_gazing' in gaze_output:
                batch_if_padded = gaze_output['if_padded_gazing'][start_chunk_idx:end_chunk_idx]
                batch_if_padded = batch_if_padded[:, :expected_seq_len]
                batch_gaze_output['if_padded_gazing'] = batch_if_padded
            # Reconstruct frame by frame for this batch
            batch_video_dict = {"video": batch_videos}
            # Pre-allocate batch_reconstructions tensor to avoid list + stack memory spike
            batch_reconstructions = torch.zeros((16, batch_size, C, 224, 224), device=device)
            for frame_idx in range(16):
                # Update progress for each frame
                frame_pct = int(((autogaze_batch_idx * 16 + frame_idx + 1) / (num_autogaze_batches * 16)) * 100)
                if progress_callback:
                    progress_callback(0.5 + 0.4 * ((autogaze_batch_idx * 16 + frame_idx + 1) / (num_autogaze_batches * 16)), f"Reconstruction progress: {frame_pct}%")
                    yield None
                task_output = task.forward_output(batch_video_dict, batch_gaze_output, frame_idx_to_reconstruct=[frame_idx])
                batch_reconstructions[frame_idx] = task_output['reconstruction'][:, 0]  # (recon_batch_size, C, H, W)
                del task_output
            # Reorder from (16, recon_batch_size, C, H, W) to (recon_batch_size * 16, C, H, W)
            # to match expected chunk ordering
            batch_reconstructions = rearrange(batch_reconstructions, 't b c h w -> (b t) c h w')
            # Write directly into pre-allocated tensor
            batch_size_frames = batch_reconstructions.shape[0]
            reconstruction_chunks[frame_idx_counter:frame_idx_counter+batch_size_frames] = batch_reconstructions.cpu()
            frame_idx_counter += batch_size_frames
            # Clean up batch-specific variables
            del batch_videos, batch_gaze_output, batch_video_dict, batch_reconstructions
        print('Reconstruction complete.')
        # Manually reverse the mean/std normalization to get back to [0, 1] range
        mean = torch.tensor(transform.image_mean).view(1, 3, 1, 1).to(reconstruction_chunks.device)
        std = torch.tensor(transform.image_std).view(1, 3, 1, 1).to(reconstruction_chunks.device)
        reconstruction_chunks = reconstruction_chunks * std + mean
        # Clean up video chunks and gaze output to free GPU memory (keep gazing_masks_batched for later)
        del video_chunks, video_chunks_flat, gaze_output
        # Reshape chunks back to original structure (nt, nh, nw already calculated earlier)
        print('Reshaping reconstructed chunks back to video tensor...')
        reconstruction_tensor = rearrange(reconstruction_chunks, '(nt nh nw t) C h w -> (nt t) C (nh h) (nw w)', nt=nt, nh=nh, nw=nw, t=16)
        reconstruction_tensor = reconstruction_tensor[:T, :, :H, :W]  # Remove padding
        # Move reconstruction to GPU for visualization
        reconstruction_tensor = reconstruction_tensor.to(device)
        gazing_mask_assembled = []
        for scale_idx in range(4):
            scale_masks_stacked = gazing_masks_batched[scale_idx]
            # Reshape: (num_chunks, 16, num_patches) -> (num_chunks * 16, num_patches)
            scale_masks_flat = scale_masks_stacked.reshape(-1, scale_masks_stacked.shape[-1])
            # Rearrange back to original video structure
            scale_masks_reshaped = rearrange(scale_masks_flat, '(nt nh nw t) n -> (nt t) (nh nw) n', nt=nt, nh=nh, nw=nw, t=16)
            scale_masks_reshaped = scale_masks_reshaped[:T]  # Remove temporal padding
            gazing_mask_assembled.append(scale_masks_reshaped)
            del scale_masks_stacked, scale_masks_flat, scale_masks_reshaped
        del gazing_masks_batched
        pct = total_gazing_tokens / total_possible_tokens
        # Move original video to GPU for visualization
        video_viz = video_tensor_original.to(device)
        # Generate frame-by-frame visualizations
        original_frames = []
        composite_frames = []
        reconstruction_frames = []
        scales_stitch_frames = []
        print('Visualizing...')
        if progress_callback:
            progress_callback(0.9, "Visualizing...")
            yield None
        # Loop invariants hoisted out of the per-frame loop: scale list,
        # per-scale opacities, label sizing, inter-scale padding strip, font.
        scales = setup['scales']
        alpha_values = [0.4, 0.5, 0.6, 0.7]  # Per-scale opacity (coarse to fine)
        label_bar_height = 30
        # 10px white padding between scale panels in the horizontal stitch
        padding = np.ones((H + label_bar_height, 10, 3), dtype=np.uint8) * 255
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
        except OSError:
            # Font file missing on this system; fall back to PIL's built-in font.
            font = ImageFont.load_default()
        for t in trange(T):
            # Original frame
            frame = video_viz[t].permute(1, 2, 0)
            frame = torch.clip(frame, 0, 1)
            frame_uint8 = (frame * 255).byte().cpu().numpy()
            original_frames.append(frame_uint8)
            # Reconstruction frame
            recon_frame = reconstruction_tensor[t].permute(1, 2, 0)
            recon_frame = torch.clip(recon_frame, 0, 1)
            recon_uint8 = (recon_frame * 255).byte().cpu().numpy()
            reconstruction_frames.append(recon_uint8)
            # Composite overlay: blend masked, per-scale versions of the frame.
            composite = torch.zeros((H, W, 3)).to(device)
            for scale_idx in range(4):
                scale = scales[scale_idx]
                scale_h = int(scale * H / 224)
                scale_w = int(scale * W / 224)
                # Get mask for this scale and frame
                mask = gazing_mask_assembled[scale_idx][t]  # (nh * nw, num_patches)
                # Reshape mask: (nh * nw, num_patches) where num_patches = s^2
                num_patches_per_chunk = mask.shape[-1]
                s = int(num_patches_per_chunk ** 0.5)
                # Rearrange to 2D spatial grid
                mask_2d = rearrange(mask, '(nh nw) (h w) -> (nh h) (nw w)', nh=nh, nw=nw, h=s, w=s)
                # Convert to tensor if needed
                if isinstance(mask_2d, np.ndarray):
                    mask_tensor = torch.from_numpy(mask_2d)
                else:
                    mask_tensor = mask_2d
                # Map mask through padded space then crop to original image dimensions
                H_pad, W_pad = nh * 224, nw * 224
                mask_full = F.interpolate(mask_tensor.unsqueeze(0).unsqueeze(0).float(), size=(H_pad, W_pad), mode='nearest')[0, 0]
                mask_resized = F.interpolate(mask_full[:H, :W].unsqueeze(0).unsqueeze(0), size=(scale_h, scale_w), mode='nearest')[0, 0]
                frame_tensor = video_viz[t]
                frame_scaled = F.interpolate(frame_tensor.unsqueeze(0), size=(scale_h, scale_w), mode='bicubic', align_corners=False).squeeze().clamp(0, 1)
                frame_scaled_masked = frame_scaled * mask_resized.unsqueeze(0)
                # Upsample both masked frame and mask to full size
                frame_upsampled = F.interpolate(frame_scaled_masked.unsqueeze(0), size=(H, W), mode='nearest').squeeze()
                mask_upsampled = F.interpolate(mask_resized.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest').squeeze()
                frame_upsampled = frame_upsampled.permute(1, 2, 0)
                composite = composite * (1 - mask_upsampled[:, :, None] * alpha_values[scale_idx]) + frame_upsampled * alpha_values[scale_idx]
            composite_np = composite.detach().cpu().numpy()
            # Normalize blend to [0, 1] before converting to uint8.
            composite_np = (composite_np - composite_np.min()) / (composite_np.max() - composite_np.min() + 1e-8)
            composite_uint8 = (composite_np * 255).astype(np.uint8)
            composite_frames.append(composite_uint8)
            # Create individual scale visualizations for horizontal stitch
            scale_composites = []
            for scale_idx in range(4):
                scale = scales[scale_idx]
                scale_h = int(scale * H / 224)
                scale_w = int(scale * W / 224)
                # Get mask for this scale and frame
                mask = gazing_mask_assembled[scale_idx][t]
                # Reshape mask to 2D spatial grid
                num_patches_per_chunk = mask.shape[-1]
                s = int(num_patches_per_chunk ** 0.5)
                mask_2d = rearrange(mask, '(nh nw) (h w) -> (nh h) (nw w)', nh=nh, nw=nw, h=s, w=s)
                if isinstance(mask_2d, np.ndarray):
                    mask_tensor_scale = torch.from_numpy(mask_2d)
                else:
                    mask_tensor_scale = mask_2d
                # Map mask through padded space then crop to original image dimensions
                H_pad, W_pad = nh * 224, nw * 224
                mask_full_scale = F.interpolate(mask_tensor_scale.unsqueeze(0).unsqueeze(0).float(), size=(H_pad, W_pad), mode='nearest')[0, 0]
                mask_resized_scale = F.interpolate(mask_full_scale[:H, :W].unsqueeze(0).unsqueeze(0), size=(scale_h, scale_w), mode='nearest')[0, 0]
                frame_tensor_scale = video_viz[t]
                frame_scaled_scale = F.interpolate(frame_tensor_scale.unsqueeze(0), size=(scale_h, scale_w), mode='bicubic', align_corners=False).squeeze().clamp(0, 1)
                # Apply gazing pattern: gazed tiles = 1.0 brightness, ungazed tiles = 0.2 brightness
                frame_scaled_permuted = frame_scaled_scale.permute(1, 2, 0)
                scale_composite = frame_scaled_permuted * (mask_resized_scale[:, :, None] * 1.0 + (1 - mask_resized_scale[:, :, None]) * 0.2)
                scale_composite_np = scale_composite.detach().cpu().numpy()
                scale_composite_np = np.clip(scale_composite_np, 0, 1)
                scale_composite_uint8 = (scale_composite_np * 255).astype(np.uint8)
                # Resize visualization to common display height first (preserving aspect ratio)
                display_width = int(scale_w * H / scale_h)
                scale_composite_pil = Image.fromarray(scale_composite_uint8)
                scale_composite_resized = scale_composite_pil.resize((display_width, H), Image.NEAREST)
                scale_composite_resized_np = np.array(scale_composite_resized)
                # Create label bar matching the resized visualization width
                label_bar = np.ones((label_bar_height, display_width, 3), dtype=np.uint8) * 255
                label_bar_pil = Image.fromarray(label_bar)
                draw = ImageDraw.Draw(label_bar_pil)
                label = f"Scale {scale_idx + 1}"
                draw.text((5, 5), label, fill=(0, 0, 0), font=font)
                label_bar_np = np.array(label_bar_pil)
                # Stack label bar above the visualization
                scale_with_label = np.vstack([label_bar_np, scale_composite_resized_np])
                scale_composites.append(scale_with_label)
            # Concatenate all scales horizontally with 10px white padding between them
            stitched = scale_composites[0]
            for i in range(1, 4):
                stitched = np.concatenate([stitched, padding, scale_composites[i]], axis=1)
            # Add white padding at the top to prevent Gradio's label from blocking content
            top_padding = np.ones((50, stitched.shape[1], 3), dtype=np.uint8) * 255
            stitched = np.vstack([top_padding, stitched])
            scales_stitch_frames.append(stitched)
            del frame_tensor, mask_tensor, mask_resized, frame_scaled, frame_scaled_masked, frame_upsampled, mask_upsampled
        del gazing_mask_assembled
        del video_tensor_original, reconstruction_tensor, video_viz, reconstruction_chunks
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        yield {
            'original_frames': original_frames,
            'gazing_frames': composite_frames,
            'reconstruction_frames': reconstruction_frames,
            'scales_stitch_frames': scales_stitch_frames,
            'fps': fps,
            'gazing_pct': pct,
            'total_gazing_tokens': total_gazing_tokens,
            'total_possible_tokens': total_possible_tokens
        }
def save_video(frames, output_path, fps):
    """Encode a sequence of RGB frames as an H.264 video at the given fps."""
    writer = imageio.get_writer(
        output_path,
        fps=fps,
        format='FFMPEG',
        codec='libx264',
        pixelformat='yuv420p',
        macro_block_size=1,
    )
    try:
        for image in frames:
            writer.append_data(image)
    finally:
        writer.close()