| |
| try: |
| import spaces |
| ZEROGPU_AVAILABLE = True |
| except ImportError: |
| ZEROGPU_AVAILABLE = False |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import av |
| import imageio |
| from transformers import VivitImageProcessor |
| from PIL import Image, ImageDraw, ImageFont |
| from omegaconf import OmegaConf |
| from einops import rearrange |
| from tqdm import trange |
| from autogaze.models.autogaze import AutoGaze |
| from autogaze.datasets.video_utils import read_video_pyav, transform_video_for_pytorch |
| from autogaze.tasks.video_mae_reconstruction import VideoMAEReconstruction |
| from autogaze.utils import UnNormalize |
|
|
|
|
def image_to_video(image_path, output_path, fps):
    """
    Convert a single image to a single-frame video file.

    The image is converted to RGB if needed and written as one frame with
    imageio's FFMPEG writer (libx264, yuv444p, macro_block_size=1).

    Args:
        image_path: Path to input image
        output_path: Path to output video file
        fps: Frame rate for the video

    Returns:
        Dictionary with video metadata (width, height, frames, fps)
    """
    # Open in a context manager so the underlying file handle is released
    # as soon as the pixels have been copied into the array.
    with Image.open(image_path) as img:
        rgb = img if img.mode == 'RGB' else img.convert('RGB')
        img_array = np.array(rgb)

    with imageio.get_writer(output_path, fps=fps, format='FFMPEG', codec='libx264', pixelformat='yuv444p', macro_block_size=1) as writer:
        writer.append_data(img_array)

    return {
        'width': img_array.shape[1],
        'height': img_array.shape[0],
        'frames': 1,
        'fps': fps
    }
|
|
|
|
def load_model(device='cuda'):
    """
    Load AutoGaze and the VideoMAE reconstruction task used by process_video.

    Args:
        device: Device both models are moved to.

    Returns:
        Dictionary with:
            'model': AutoGaze loaded from "nvidia/AutoGaze", in eval mode.
            'task': VideoMAEReconstruction whose ``mae`` weights come from the
                "bfshi/VideoMAE_AutoGaze" checkpoint, in eval mode.
            'unnorm': UnNormalize built from the image processor's mean/std/rescale factor.
            'scales': ``model.scales``.
            'transform': VivitImageProcessor sized/cropped to the largest scale.
    """
    print("Loading AutoGaze model from HuggingFace...")
    model = AutoGaze.from_pretrained("nvidia/AutoGaze")
    model = model.to(device)
    model.eval()

    # Normalization statistics come from ViT-MAE; resize/crop to AutoGaze's largest scale.
    transform = VivitImageProcessor.from_pretrained(
        "facebook/vit-mae-large",
        size=model.scales[-1],
        crop_size=model.scales[-1]
    )

    unnorm = UnNormalize(
        mean=transform.image_mean,
        std=transform.image_std,
        rescale_factor=transform.rescale_factor
    )

    print("Loading VideoMAE model from HuggingFace...")
    scales_str = '+'.join(map(str, model.scales))
    recon_model_config = OmegaConf.create({
        'scale_embed': True,
        'max_num_frames': 256,
        'time_embed': True,
        'causal': True,
        'loss_type': 'l1+dinov2_reg+siglip2',
        'loss_weights': '1',
        'l1_loss_config': {},
        'dinov2_reg_loss_config': {
            'model': 'facebook/dinov2-with-registers-base'
        },
        'siglip2_loss_config': {
            'model': 'google/siglip2-base-patch16-224'
        }
    })
    task = VideoMAEReconstruction(
        recon_model='facebook/vit-mae-large',
        recon_model_config=recon_model_config,
        scales=scales_str,
        recon_sample_rate=1,
        attn_mode='sdpa'
    )

    from huggingface_hub import hf_hub_download
    checkpoint_path = hf_hub_download(repo_id="bfshi/VideoMAE_AutoGaze", filename="videomae.pt")
    print(f"Loading VideoMAE checkpoint from {checkpoint_path}...")
    # The checkpoint is downloaded from the network. weights_only=True restricts
    # unpickling to tensors and plain containers, so a tampered file cannot run
    # arbitrary code. The payload is used directly as a state dict below.
    task_sd = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    # Drop the 'module.mae.' key prefix (presumably left by a training wrapper)
    # so the keys match task.mae's own state dict.
    task_sd = {k.replace('module.mae.', ''): v for k, v in task_sd.items()}
    task.mae.load_state_dict(task_sd, strict=True)
    print("Loaded VideoMAE checkpoint from HuggingFace")

    task = task.to(device)
    task.eval()

    return {
        'model': model,
        'task': task,
        'unnorm': unnorm,
        'scales': model.scales,
        'transform': transform,
    }
|
|
|
|
def _read_video_frames(video_path):
    """Decode every frame of the first video stream in *video_path*.

    Returns:
        (frames, fps): ``frames`` is what ``read_video_pyav`` returns for the
        full index range (treated by the caller as a (T, H, W, C) uint8 array);
        ``fps`` is the stream's average frame rate as a float.

    Raises:
        ValueError: if the stream contains no decodable frames.
    """
    with av.open(video_path) as container:
        video_stream = container.streams.video[0]
        total_frames_available = video_stream.frames
        fps = float(video_stream.average_rate)
        if total_frames_available <= 0:
            # Some containers do not store a frame count (PyAV reports 0);
            # count the frames by decoding instead.
            total_frames_available = sum(1 for _ in container.decode(video=0))
    if total_frames_available <= 0:
        raise ValueError(f"No decodable video frames found in {video_path!r}")

    with av.open(video_path) as container:
        video = read_video_pyav(container=container, indices=list(range(total_frames_available)))
    return video, fps


@torch.inference_mode()
def _run_autogaze(model, clips, gazing_ratio, task_loss_requirement):
    """Run AutoGaze on a batch of clips under inference mode.

    This is a plain (non-generator) function, so inference mode is entered and
    exited within one call and is never held open across a ``yield``.
    """
    return model({"video": clips}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement)


@torch.inference_mode()
def _reconstruct_frame(task, video_dict, gaze_output, frame_idx):
    """Return ``reconstruction[:, 0]`` from ``task.forward_output`` for one frame index, under inference mode."""
    task_output = task.forward_output(video_dict, gaze_output, frame_idx_to_reconstruct=[frame_idx])
    return task_output['reconstruction'][:, 0]


def _count_gazing_tokens(gaze_output, num_frames):
    """Count the non-padding gazing tokens in one AutoGaze output.

    When ``if_padded_gazing`` is present, it is truncated to the number of
    gazes reported for the first ``num_frames`` entries of
    ``num_gazing_each_frame`` and the un-padded positions are counted.
    Otherwise, ``gazing_pos`` entries below ``196 * num_frames`` are counted.
    """
    if_padded = gaze_output.get('if_padded_gazing')
    if if_padded is None:
        return (gaze_output['gazing_pos'] < (196 * num_frames)).sum().item()
    print(f'shape of if_padded: {if_padded.shape}')
    num_gazing_total = gaze_output['num_gazing_each_frame'][:num_frames].sum().item()
    if_padded = if_padded[:, :min(num_gazing_total, if_padded.shape[1])]
    return (~if_padded).sum().item()


def _slim_gaze_output(gaze_output):
    """Keep only the AutoGaze output entries that reconstruction and visualization use."""
    slim = {
        key: gaze_output[key]
        for key in ('gazing_pos', 'gazing_mask', 'num_gazing_each_frame',
                    'frame_sampling_rate', 'num_vision_tokens_each_frame')
    }
    if_padded = gaze_output.get('if_padded_gazing')
    if if_padded is not None:
        slim['if_padded_gazing'] = if_padded
    return slim


def _load_label_font():
    """Load DejaVuSans-Bold at size 20, falling back to Pillow's default font if unavailable."""
    try:
        return ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
    except (OSError, ImportError):
        # Missing font file (OSError) or Pillow built without FreeType (ImportError).
        return ImageFont.load_default()


def _frame_to_uint8(frame_chw):
    """Convert a (C, H, W) float frame to an (H, W, C) uint8 array, clipping to [0, 1] first."""
    return (torch.clip(frame_chw.permute(1, 2, 0), 0, 1) * 255).byte().cpu().numpy()


def _scale_mask_and_frame(mask, frame, scale, nh, nw, height, width, tile):
    """Resample one frame's gaze mask and the frame itself to the resolution of one scale.

    Args:
        mask: Gaze mask for one frame at one scale. Assumed to be
            (nh * nw, s * s), i.e. s x s patches per tile, with tiles in
            (nh, nw) order. TODO confirm against AutoGaze.
        frame: (C, height, width) frame in [0, 1].
        scale: Side length of this scale, relative to ``tile``.

    Returns:
        (mask_resized, frame_scaled): an (h, w) float mask and a (C, h, w)
        frame clamped to [0, 1], where h = int(scale * height / tile) and
        w = int(scale * width / tile).
    """
    scale_h = int(scale * height / tile)
    scale_w = int(scale * width / tile)
    side = int(mask.shape[-1] ** 0.5)
    mask_2d = torch.as_tensor(rearrange(mask, '(nh nw) (h w) -> (nh h) (nw w)', nh=nh, nw=nw, h=side, w=side))
    # Blow the patch grid up to the padded pixel grid, crop the padding away,
    # then resample to this scale's resolution.
    mask_full = F.interpolate(mask_2d[None, None].float(), size=(nh * tile, nw * tile), mode='nearest')[0, 0]
    mask_resized = F.interpolate(mask_full[:height, :width][None, None], size=(scale_h, scale_w), mode='nearest')[0, 0]
    frame_scaled = F.interpolate(frame[None], size=(scale_h, scale_w), mode='bicubic', align_corners=False)[0].clamp(0, 1)
    return mask_resized, frame_scaled


def _render_scale_panel(frame_scaled, mask_resized, out_height, label, font, label_bar_height):
    """Render one per-scale panel as an (label_bar_height + out_height, W', 3) uint8 array.

    Gazed pixels keep full brightness and the rest are dimmed to 20%. The panel
    is resized (nearest neighbour) to ``out_height`` while keeping its aspect
    ratio, and a white label bar with ``label`` in black is stacked on top.
    """
    m = mask_resized[:, :, None]
    shaded = frame_scaled.permute(1, 2, 0) * (m * 1.0 + (1 - m) * 0.2)
    shaded_uint8 = (np.clip(shaded.detach().cpu().numpy(), 0, 1) * 255).astype(np.uint8)

    scale_h, scale_w = shaded_uint8.shape[:2]
    display_width = int(scale_w * out_height / scale_h)
    panel = np.array(Image.fromarray(shaded_uint8).resize((display_width, out_height), Image.NEAREST))

    label_bar = Image.fromarray(np.full((label_bar_height, display_width, 3), 255, dtype=np.uint8))
    ImageDraw.Draw(label_bar).text((5, 5), label, fill=(0, 0, 0), font=font)
    return np.vstack([np.array(label_bar), panel])


def process_video(video_path, setup, gazing_ratio=0.75, task_loss_requirement=0.6, progress_callback=None, spatial_batch_size=16):
    """
    Process a video file with AutoGaze using chunking for any resolution/duration.

    The video is zero-padded to a multiple of 16 frames and 224x224 pixels and
    cut into 16-frame 224x224 clips, ordered (temporal window, tile row, tile
    column). AutoGaze runs on the clips in mini-batches, VideoMAE reconstructs
    every frame from the gazed patches, and the results are stitched back to
    the original (T, H, W) extent for visualization.

    Args:
        video_path: Path to video file
        setup: Dictionary with model, task, unnorm, scales, transform (as returned by load_model)
        gazing_ratio: Maximum percentage of patches to gaze per frame
        task_loss_requirement: Reconstruction loss threshold
        progress_callback: Optional callable ``(fraction, message)`` for progress updates.
        spatial_batch_size: Each AutoGaze mini-batch holds
            ``spatial_batch_size * nt`` consecutive clips, where nt is the
            number of 16-frame windows. Values below 1 are treated as 1. The
            value is halved (minimum 1) for videos longer than 200 frames.

    Yields:
        ``None`` after each progress update (only when ``progress_callback`` is
        given), then a final dictionary with original frames, gazing frames,
        reconstruction frames, per-scale stitched frames, fps and gazing statistics.

    Raises:
        ValueError: if the file contains no decodable video frames.
    """
    clip_len, tile = 16, 224  # frames per clip and tile side (pixels) fed to the models
    num_scales = 4  # the visualization palette (alpha_values) covers four gaze scales
    model = setup['model']
    task = setup['task']
    transform = setup['transform']
    scales = setup['scales']
    device = next(model.parameters()).device
    # Check the device *type*, which is 'cuda' for every CUDA device index (cuda:0, cuda:1, ...).
    on_cuda = device.type == 'cuda'
    if on_cuda:
        torch.cuda.empty_cache()

    video, fps = _read_video_frames(video_path)

    # (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1].
    video_tensor = (torch.from_numpy(video).float() / 255.0).permute(0, 3, 1, 2)
    T, C, H, W = video_tensor.shape
    # A batch size of 0 would make the batch count below divide by zero.
    spatial_batch_size = max(1, spatial_batch_size)
    if T > 200:
        spatial_batch_size = max(1, spatial_batch_size // 2)
        print(f'Video has {T} frames, which may require significant GPU memory. Decreasing spatial_batch_size to {spatial_batch_size}.')

    # Unpadded, unnormalized reference for visualization. No clone is needed
    # because every later operation on video_tensor is out-of-place.
    video_tensor_original = video_tensor

    # Zero-pad to whole clips/tiles. F.pad takes (before, after) pairs starting
    # from the last dim: (W, H, C, T).
    pad_t = (clip_len - T % clip_len) % clip_len
    pad_h = (tile - H % tile) % tile
    pad_w = (tile - W % tile) % tile
    if pad_t > 0 or pad_h > 0 or pad_w > 0:
        video_tensor = F.pad(video_tensor, (0, pad_w, 0, pad_h, 0, 0, 0, pad_t))

    nt = (T + pad_t) // clip_len
    nh = (H + pad_h) // tile
    nw = (W + pad_w) // tile
    num_spatial_chunks = nh * nw
    num_chunks = nt * num_spatial_chunks

    # Clips are laid out along dim 0 in (nt, nh, nw) order, temporal window
    # outermost. Batching, reconstruction and mask reassembly all rely on it.
    video_chunks = rearrange(video_tensor.unsqueeze(0), 'B (nt t) C (nh h) (nw w) -> (B nt nh nw) t C h w', t=clip_len, h=tile, w=tile)
    del video_tensor
    print(f"Video chunked into {num_chunks} chunks ({nt} temporal x {num_spatial_chunks} spatial) of shape ({clip_len}, {C}, {tile}, {tile}). Original shape: ({T}, {C}, {H}, {W})")

    mean = torch.tensor(transform.image_mean).view(1, 1, 3, 1, 1)
    std = torch.tensor(transform.image_std).view(1, 1, 3, 1, 1)
    video_chunks = (video_chunks - mean) / std
    print(f'video_chunks shape (chunks, frames, C, H, W): {video_chunks.shape}')

    # ---- Gazing: run AutoGaze on consecutive slices of the clip axis. ----
    clips_per_batch = spatial_batch_size * nt
    num_spatial_batches = (num_spatial_chunks + spatial_batch_size - 1) // spatial_batch_size
    gaze_batches = []
    total_gazing_tokens = 0
    for batch_idx in range(num_spatial_batches):
        if progress_callback:
            gazing_pct = int(((batch_idx + 1) / num_spatial_batches) * 100)
            progress_callback(0.1 + 0.4 * (batch_idx / num_spatial_batches), f"Gazing progress: {gazing_pct}%")
            yield None

        first = batch_idx * clips_per_batch
        clips = video_chunks[first:first + clips_per_batch].to(device)
        print(f'Processing spatial batch {batch_idx+1}/{num_spatial_batches} with {clips.shape[0] // nt} spatial locations x {nt} temporal = {clips.shape[0]} chunks')
        raw_output = _run_autogaze(model, clips, gazing_ratio, task_loss_requirement)
        del clips

        new_gazing_tokens = _count_gazing_tokens(raw_output, T)
        total_gazing_tokens += new_gazing_tokens
        print(f'Batch {batch_idx+1}: Gazing tokens = {new_gazing_tokens}, Total gazing tokens so far = {total_gazing_tokens}')

        # Keep only what is needed later so other model outputs can be freed now.
        gaze_batches.append(_slim_gaze_output(raw_output))
        del raw_output
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # NOTE(review): 196 is presumably (224 / 16) ** 2 patches per tile at one
    # scale; confirm it matches how AutoGaze counts multi-scale tokens.
    total_possible_tokens = 196 * min(T, clip_len) * num_chunks

    # ---- Reconstruction: one VideoMAE pass per frame index per AutoGaze batch. ----
    num_autogaze_batches = len(gaze_batches)
    print(f'Reconstructing {num_chunks} chunks in {num_autogaze_batches} batches (aligned with AutoGaze batches)...')
    reconstruction_chunks = torch.zeros((num_chunks * clip_len, C, tile, tile), dtype=torch.float32)
    total_steps = num_autogaze_batches * clip_len
    chunk_idx = 0
    for batch_idx, gaze in enumerate(gaze_batches):
        batch_size = gaze['gazing_pos'].shape[0]
        print(f'Reconstructing chunks {chunk_idx+1}-{chunk_idx + batch_size}/{num_chunks}...')
        batch_video_dict = {"video": video_chunks[chunk_idx:chunk_idx + batch_size].to(device)}

        # Trim the gaze sequence to the number of gazes the batch reports.
        expected_seq_len = gaze['num_gazing_each_frame'].sum().item()
        batch_gaze_output = {
            'gazing_pos': gaze['gazing_pos'][:, :expected_seq_len],
            'gazing_mask': gaze['gazing_mask'],
            'num_gazing_each_frame': gaze['num_gazing_each_frame'],
            'frame_sampling_rate': gaze['frame_sampling_rate'],
            'num_vision_tokens_each_frame': gaze['num_vision_tokens_each_frame'],
        }
        if 'if_padded_gazing' in gaze:
            batch_gaze_output['if_padded_gazing'] = gaze['if_padded_gazing'][:, :expected_seq_len]

        per_frame = []
        for frame_idx in range(clip_len):
            if progress_callback:
                done = (batch_idx * clip_len + frame_idx + 1) / total_steps
                progress_callback(0.5 + 0.4 * done, f"Reconstruction progress: {int(done * 100)}%")
                yield None
            per_frame.append(_reconstruct_frame(task, batch_video_dict, batch_gaze_output, frame_idx))

        # (t, b) -> (b, t): clip-major, frame-minor, matching the (nt nh nw t) layout.
        batch_reconstructions = rearrange(torch.stack(per_frame, dim=0), 't b c h w -> (b t) c h w')
        reconstruction_chunks[chunk_idx * clip_len:(chunk_idx + batch_size) * clip_len] = batch_reconstructions.cpu()
        chunk_idx += batch_size
        del batch_video_dict, batch_gaze_output, per_frame, batch_reconstructions
    print('Reconstruction complete.')
    del video_chunks

    # Undo the normalization applied to the inputs.
    mean = torch.tensor(transform.image_mean).view(1, 3, 1, 1)
    std = torch.tensor(transform.image_std).view(1, 3, 1, 1)
    reconstruction_chunks = reconstruction_chunks * std + mean

    print('Reshaping reconstructed chunks back to video tensor...')
    reconstruction_tensor = rearrange(reconstruction_chunks, '(nt nh nw t) C h w -> (nt t) C (nh h) (nw w)', nt=nt, nh=nh, nw=nw, t=clip_len)
    reconstruction_tensor = reconstruction_tensor[:T, :, :H, :W].to(device)
    del reconstruction_chunks

    # Reassemble each scale's gaze masks into (T, nh * nw, patches) in frame order.
    gazing_mask_assembled = []
    for scale_idx in range(num_scales):
        stacked = torch.cat([gaze['gazing_mask'][scale_idx] for gaze in gaze_batches], dim=0)
        flat = stacked.reshape(-1, stacked.shape[-1])
        gazing_mask_assembled.append(
            rearrange(flat, '(nt nh nw t) n -> (nt t) (nh nw) n', nt=nt, nh=nh, nw=nw, t=clip_len)[:T]
        )
    del gaze_batches

    pct = total_gazing_tokens / total_possible_tokens

    # ---- Visualization ----
    video_viz = video_tensor_original.to(device)
    del video_tensor_original
    label_font = _load_label_font()  # loaded once, not per frame and scale
    alpha_values = [0.4, 0.5, 0.6, 0.7]
    label_bar_height = 30

    original_frames = []
    composite_frames = []
    reconstruction_frames = []
    scales_stitch_frames = []

    print('Visualizing...')
    if progress_callback:
        progress_callback(0.9, "Visualizing...")
        yield None
    for t in trange(T):
        frame = video_viz[t]
        original_frames.append(_frame_to_uint8(frame))
        reconstruction_frames.append(_frame_to_uint8(reconstruction_tensor[t]))

        composite = torch.zeros((H, W, 3), device=device)
        panels = []
        for scale_idx in range(num_scales):
            # One resampling per scale feeds both the blended composite and the per-scale panel.
            mask_resized, frame_scaled = _scale_mask_and_frame(
                gazing_mask_assembled[scale_idx][t], frame, scales[scale_idx], nh, nw, H, W, tile
            )
            alpha = alpha_values[scale_idx]
            masked_up = F.interpolate((frame_scaled * mask_resized.unsqueeze(0)).unsqueeze(0), size=(H, W), mode='nearest')[0].permute(1, 2, 0)
            mask_up = F.interpolate(mask_resized[None, None], size=(H, W), mode='nearest')[0, 0]
            composite = composite * (1 - mask_up[:, :, None] * alpha) + masked_up * alpha

            panels.append(_render_scale_panel(frame_scaled, mask_resized, H, f"Scale {scale_idx + 1}", label_font, label_bar_height))

        composite_np = composite.detach().cpu().numpy()
        composite_np = (composite_np - composite_np.min()) / (composite_np.max() - composite_np.min() + 1e-8)
        composite_frames.append((composite_np * 255).astype(np.uint8))

        # Panels side by side, separated by 10 px white gaps, under a 50 px white header.
        gap = np.full((H + label_bar_height, 10, 3), 255, dtype=np.uint8)
        row = [panels[0]]
        for panel in panels[1:]:
            row.extend([gap, panel])
        stitched = np.concatenate(row, axis=1)
        stitched = np.vstack([np.full((50, stitched.shape[1], 3), 255, dtype=np.uint8), stitched])
        scales_stitch_frames.append(stitched)

    del gazing_mask_assembled, reconstruction_tensor, video_viz
    if on_cuda:
        torch.cuda.empty_cache()

    yield {
        'original_frames': original_frames,
        'gazing_frames': composite_frames,
        'reconstruction_frames': reconstruction_frames,
        'scales_stitch_frames': scales_stitch_frames,
        'fps': fps,
        'gazing_pct': pct,
        'total_gazing_tokens': total_gazing_tokens,
        'total_possible_tokens': total_possible_tokens
    }
|
|
|
|
def save_video(frames, output_path, fps):
    """Encode *frames* in order into a video at *output_path*.

    Uses imageio's FFMPEG writer with libx264 and yuv444p. It also passes
    ``macro_block_size=1``, an imageio FFMPEG option that presumably keeps the
    frame size unchanged; see the imageio docs.
    """
    writer_options = dict(
        fps=fps,
        format='FFMPEG',
        codec='libx264',
        pixelformat='yuv444p',
        macro_block_size=1,
    )
    with imageio.get_writer(output_path, **writer_options) as video_writer:
        for image in frames:
            video_writer.append_data(image)
|
|