Spaces:
Sleeping
Sleeping
File size: 1,169 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | import torch
def split_to_minibatch(batch_split, iter_context_idxs):
    """
    Build a one-item minibatch containing a subset of context views.

    Only batch element 0 of each entry is read. Its views are selected with
    ``iter_context_idxs`` and a leading batch axis of size 1 is restored.

    Args:
        batch_split: mapping with the keys "image", "extrinsics",
            "intrinsics", "near" and "far". Each value is indexed as
            ``value[0][iter_context_idxs]``, so presumably these are tensors
            with a leading batch axis (TODO confirm against callers).
        iter_context_idxs: index (tensor, list or mask) applied to the view
            axis of batch element 0.

    Returns:
        dict with the same five keys, in the same order. Each value has a
        leading axis of size 1. Per the original annotations, the expected
        shapes are image [1, Vc', 3, Hc, Wc], extrinsics [1, Vc', 4, 4],
        intrinsics [1, Vc', 4, 4] (NOTE(review): intrinsics may actually be
        3x3, verify) and near/far [1, Vc'].
    """
    field_names = ("image", "extrinsics", "intrinsics", "near", "far")
    # Batch element 0 only: callers presumably pass a batch of size 1.
    return {
        field: batch_split[field][0][iter_context_idxs].unsqueeze(0)
        for field in field_names
    }
def batched_select(data, indices):
    """
    Select ``data[i, indices[i]]`` for each batch element ``i``.

    Args:
        data: [B, N, ...] input tensor.
        indices: integer tensor of shape [B, K] (more generally [B, *K]).
            A 1-D [B] tensor selects a single element per batch item.

    Returns:
        Tensor of shape [B, *K, ...] with ``out[i, k...] = data[i, indices[i, k...]]``.

    Raises:
        ValueError: if ``indices`` has no batch dimension (0-D), or if its
            batch size differs from ``data``'s.
    """
    # Check the rank first: indices.shape[0] raises IndexError on a 0-D tensor.
    # Use explicit raises rather than asserts, because asserts vanish under -O.
    # Without them, mis-shaped indices would silently broadcast to a wrong result.
    if indices.dim() < 1:
        raise ValueError(
            f"indices should have a leading batch dimension, got {tuple(indices.shape)}"
        )
    if data.shape[0] != indices.shape[0]:
        raise ValueError(f"Batch size mismatch {data.shape[0]} vs {indices.shape[0]}")
    B = data.shape[0]
    # Shape the batch index as [B, 1, ..., 1] (one trailing 1 per extra index dim).
    # It then broadcasts row-wise against indices: out[i, ...] pairs batch i with indices[i, ...].
    batch_idx = torch.arange(B, device=data.device).view(B, *([1] * (indices.dim() - 1)))
    return data[batch_idx, indices]
|