# Extracted from: zoo3d/vggt/utils/helper.py (uploaded by drozdgk; commit "init", 4eeefd1)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
def randomly_limit_trues(mask: np.ndarray, max_trues: int, rng=None) -> np.ndarray:
    """
    Cap the number of True entries in a mask by random subsampling.

    Args:
        mask: Array whose non-zero (True) entries are the candidates to keep.
        max_trues: Maximum number of True entries allowed in the result.
            Must be non-negative.
        rng: Optional ``np.random.Generator`` used to draw the sample. When
            None (the default) the global ``np.random`` state is used, so
            existing callers (and ``np.random.seed``) behave exactly as before.

    Returns:
        ``mask`` itself (the same object, not a copy) when it already holds at
        most ``max_trues`` True entries. Otherwise a new boolean array with
        the shape of ``mask`` and exactly ``max_trues`` True entries, sampled
        without replacement from the True positions of ``mask``.

    Raises:
        ValueError: If ``max_trues`` is negative.
    """
    # Fail early with a readable message instead of an error from inside NumPy.
    if max_trues < 0:
        raise ValueError(f"max_trues must be non-negative, got {max_trues}")

    # Flat (row-major) positions of all True entries, shape = (N_true,)
    true_indices = np.flatnonzero(mask)

    # Already within budget: hand the input back untouched.
    if true_indices.size <= max_trues:
        return mask

    # Default to the legacy global-state sampler so seeded callers are unaffected;
    # an explicit Generator gives reproducible draws without touching global state.
    choose = np.random.choice if rng is None else rng.choice
    sampled_indices = choose(true_indices, size=max_trues, replace=False)

    # Build the new flat mask (True only at the sampled positions), then
    # restore the original shape.
    limited_flat_mask = np.zeros(mask.size, dtype=bool)
    limited_flat_mask[sampled_indices] = True
    return limited_flat_mask.reshape(mask.shape)
def create_pixel_coordinate_grid(num_frames: int, height: int, width: int) -> np.ndarray:
    """
    Creates a grid of pixel coordinates and frame indices for all frames.

    Args:
        num_frames: Number of frames (size of the first axis).
        height: Number of pixel rows per frame (size of the second axis).
        width: Number of pixel columns per frame (size of the third axis).

    Returns:
        numpy.ndarray: float32 array of shape (num_frames, height, width, 3).
        The entry at index ``[f, y, x]`` is ``(x, y, f)``: channel 0 is the
        column index, channel 1 the row index and channel 2 the frame index.
    """
    # Allocate the output once and fill each channel via broadcasting, which
    # avoids building intermediate per-axis grids.
    points_xyf = np.empty((num_frames, height, width, 3), dtype=np.float32)

    # x varies along the last spatial axis (columns).
    points_xyf[..., 0] = np.arange(width, dtype=np.float32)
    # y varies along the row axis.
    points_xyf[..., 1] = np.arange(height, dtype=np.float32)[:, np.newaxis]
    # The frame index varies along the leading axis.
    points_xyf[..., 2] = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]

    return points_xyf