# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np


def randomly_limit_trues(
    mask: np.ndarray,
    max_trues: int,
    *,
    rng: "np.random.Generator | np.random.RandomState | None" = None,
) -> np.ndarray:
    """
    If mask has more than max_trues True values, randomly keep only max_trues
    of them and set the rest to False.

    Args:
        mask: Array of any shape. Nonzero entries count as True.
        max_trues: Maximum number of True entries to keep. Must be >= 0.
        rng: Optional NumPy random source (``np.random.Generator`` or
            ``np.random.RandomState``) used for sampling, so results can be
            reproduced. When None (the default), the global ``np.random``
            state is used.

    Returns:
        np.ndarray: ``mask`` itself (the same object, not a copy) when it has
        at most ``max_trues`` True entries. Otherwise, a new boolean array of
        the same shape with exactly ``max_trues`` True entries, each at a
        position that was True in ``mask``.

    Raises:
        ValueError: If ``max_trues`` is negative.
    """
    if max_trues < 0:
        raise ValueError(f"max_trues must be non-negative, got {max_trues}")

    # 1D (flattened) positions of all True entries, shape = (N_true,)
    true_indices = np.flatnonzero(mask)

    # Already within budget: hand back the input untouched (no copy).
    if true_indices.size <= max_trues:
        return mask

    # Randomly pick which True positions to keep, shape = (max_trues,).
    # Sampling without replacement guarantees max_trues distinct positions.
    chooser = np.random if rng is None else rng
    sampled_indices = chooser.choice(true_indices, size=max_trues, replace=False)

    # Build a new flat mask that is True only at the sampled positions,
    # then restore the original shape.
    limited_flat_mask = np.zeros(mask.size, dtype=bool)
    limited_flat_mask[sampled_indices] = True
    return limited_flat_mask.reshape(mask.shape)


def create_pixel_coordinate_grid(num_frames, height, width):
    """
    Creates a grid of pixel coordinates and frame indices for all frames.

    Args:
        num_frames (int): Number of frames.
        height (int): Number of pixel rows per frame.
        width (int): Number of pixel columns per frame.

    Returns:
        numpy.ndarray: float32 array ``points_xyf`` of shape
        (num_frames, height, width, 3). Entry ``points_xyf[f, y, x]`` equals
        ``(x, y, f)``: the column index, the row index, and the frame index
        of that pixel.
    """
    shape = (num_frames, height, width)

    # Row (y) and column (x) index of every pixel in one frame, shape (H, W).
    y_grid, x_grid = np.indices((height, width), dtype=np.float32)

    # Broadcasting prepends the frame axis: read-only views of shape (F, H, W).
    x_coords = np.broadcast_to(x_grid, shape)
    y_coords = np.broadcast_to(y_grid, shape)

    # Frame index of every pixel, shape (F, H, W).
    f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]
    f_coords = np.broadcast_to(f_idx, shape)

    # np.stack materializes a new contiguous (F, H, W, 3) array.
    points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
    return points_xyf