Spaces:
Sleeping
Sleeping
| import torch | |
def split_to_minibatch(batch_split, iter_context_idxs):
    """Build a single-sample minibatch restricted to the selected context views.

    For each of the keys ``image``, ``extrinsics``, ``intrinsics``, ``near``
    and ``far``, the first batch element of ``batch_split[key]`` is taken,
    indexed along its leading axis with ``iter_context_idxs``, and given a
    new leading axis of size 1.

    Args:
        batch_split: Mapping that contains at least the five keys above.
            Any other keys are ignored.
        iter_context_idxs: Index (or indices) applied to the first remaining
            axis of every entry after batch element 0 has been taken.

    Returns:
        A new dict with the same five keys, in the order listed above.

    NOTE(review): the inline comments this replaces annotated the results as
    image [1, Vc', 3, Hc, Wc], extrinsics [1, Vc', 4, 4],
    intrinsics [1, Vc', 4, 4], near/far [1, Vc'] -- these shapes are not
    enforced here; confirm against the caller (in particular the 4x4
    annotation on intrinsics).
    """
    field_names = ("image", "extrinsics", "intrinsics", "near", "far")
    return {
        name: batch_split[name][0][iter_context_idxs].unsqueeze(0)
        for name in field_names
    }
def batched_select(data, indices):
    """Gather ``data[i, indices[i]]`` for every batch element ``i``.

    Args:
        data: Tensor whose first axis is the batch axis and whose second
            axis is the one being indexed, i.e. ``[B, N, ...]``.
        indices: 2-D tensor ``[B, K]`` holding, for each batch element, the
            positions to pick along the second axis of ``data``.

    Returns:
        The result of indexing ``data`` with a ``[B, 1]`` batch index and
        the ``[B, K]`` ``indices`` (the two broadcast to ``[B, K]``), so the
        output has shape ``[B, K, ...]``.

    Raises:
        AssertionError: If the leading dimensions of ``data`` and
            ``indices`` differ, or if ``indices`` is not 2-D.
    """
    assert data.shape[0] == indices.shape[0], f"Batch size mismatch {data.shape[0]} vs {indices.shape[0]}"
    assert indices.dim() == 2, f"indices should be 2D, got {indices.shape}"

    num_batches = data.shape[0]
    # Column vector [B, 1] so it broadcasts against indices [B, K] and pairs
    # row i of `indices` with batch element i of `data`.
    row_ids = torch.arange(num_batches, device=data.device).unsqueeze(1)
    return data[row_ids, indices]