| import torch |
| from pointops import knn_query, ball_query, grouping |
|
|
|
|
def knn_query_and_group(
    feat,
    xyz,
    offset=None,
    new_xyz=None,
    new_offset=None,
    idx=None,
    nsample=None,
    with_xyz=False,
):
    """Group ``feat`` over k-nearest neighbours, computing the indices on demand.

    Args:
        feat: support features forwarded to ``pointops.grouping``.
        xyz: support coordinates.
        offset: batch offsets of ``xyz`` forwarded to ``knn_query``.
        new_xyz: query coordinates forwarded to ``knn_query`` / ``grouping``.
        new_offset: batch offsets of ``new_xyz`` forwarded to ``knn_query``.
        idx: precomputed neighbour indices; when given, ``knn_query`` is skipped.
        nsample: number of neighbours; required only when ``idx`` is None.
        with_xyz: forwarded to ``grouping`` (presumably whether relative
            coordinates are included in the output — confirm in pointops).

    Returns:
        ``(grouped, idx)``: the result of ``grouping`` and the neighbour
        indices that were used (the caller's ``idx`` or the computed one).
    """
    if idx is not None:
        return grouping(idx, feat, xyz, new_xyz, with_xyz), idx

    assert nsample is not None
    # knn_query returns (indices, distances); only the indices are needed here.
    knn_idx, _unused_dist = knn_query(nsample, xyz, offset, new_xyz, new_offset)
    return grouping(knn_idx, feat, xyz, new_xyz, with_xyz), knn_idx
|
|
|
|
def ball_query_and_group(
    feat,
    xyz,
    offset=None,
    new_xyz=None,
    new_offset=None,
    idx=None,
    max_radio=None,
    min_radio=0,
    nsample=None,
    with_xyz=False,
):
    """Group ``feat`` over ball-query neighbours, computing the indices on demand.

    Args:
        feat: support features forwarded to ``pointops.grouping``.
        xyz: support coordinates.
        offset: batch offsets of ``xyz``; required when ``idx`` is None.
        new_xyz: query coordinates forwarded to ``ball_query`` / ``grouping``.
        new_offset: batch offsets of ``new_xyz`` forwarded to ``ball_query``.
        idx: precomputed neighbour indices; when given, ``ball_query`` is skipped.
        max_radio: forwarded to ``ball_query`` (presumably the outer search
            radius — confirm in pointops); required when ``idx`` is None.
        min_radio: forwarded to ``ball_query`` (presumably the inner radius);
            must not be None when ``idx`` is None.
        nsample: number of neighbours; required when ``idx`` is None.
        with_xyz: forwarded to ``grouping``.

    Returns:
        ``(grouped, idx)``: the result of ``grouping`` and the neighbour
        indices that were used (the caller's ``idx`` or the computed one).
    """
    if idx is not None:
        return grouping(idx, feat, xyz, new_xyz, with_xyz), idx

    assert nsample is not None and offset is not None
    assert max_radio is not None and min_radio is not None
    # ball_query returns a pair; only the first element (indices) is used.
    ball_idx, _unused = ball_query(
        nsample, max_radio, min_radio, xyz, offset, new_xyz, new_offset
    )
    return grouping(ball_idx, feat, xyz, new_xyz, with_xyz), ball_idx
|
|
|
|
def query_and_group(
    nsample,
    xyz,
    new_xyz,
    feat,
    idx,
    offset,
    new_offset,
    dilation=0,
    with_feat=True,
    with_xyz=True,
):
    """Gather neighbour coordinates/features around query points, with dilation.

    Args:
        nsample: number of neighbours gathered per query point.
        xyz: (n, 3) support coordinates.
        new_xyz: (m, 3) query coordinates; ``None`` means query at ``xyz``.
        feat: (n, c) support features; may be ``None`` when ``with_feat`` is False.
        idx: (m, nsample) precomputed neighbour indices, or ``None`` to compute
            them with ``knn_query``.
        offset: (b,) end index (exclusive) of each batch in ``xyz``.
        new_offset: (b,) end index (exclusive) of each batch in ``new_xyz``.
            When ``new_xyz`` is ``None`` and this is ``None`` too, ``offset`` is used.
        dilation: when computing ``idx``, take knn columns
            ``0, d+1, 2(d+1), ...`` with ``d = dilation`` (presumably the knn
            columns are ordered nearest-first — confirm in pointops).
        with_feat: if False, return only the neighbour indices.
        with_xyz: if True, prepend the neighbour coordinates relative to each
            query point to the grouped features.

    Returns:
        ``idx`` alone when ``with_feat`` is False; otherwise ``(grouped, idx)``
        where ``grouped`` is (m, nsample, 3 + c) if ``with_xyz`` else
        (m, nsample, c).
    """
    # Apply the defaults before anything dereferences new_xyz / new_offset.
    if new_xyz is None:
        new_xyz = xyz
        if new_offset is None:
            new_offset = offset
    assert xyz.is_contiguous() and new_xyz.is_contiguous()
    # Features are only touched when they are actually grouped.
    assert not with_feat or feat.is_contiguous()

    if idx is None:
        # Query enough neighbours that striding by (dilation + 1) yields nsample.
        num_samples_total = 1 + (nsample - 1) * (dilation + 1)
        idx_no_dilation, _ = knn_query(
            num_samples_total, xyz, offset, new_xyz, new_offset
        )
        batch_end = offset.tolist()
        batch_start = [0] + batch_end[:-1]
        new_batch_end = new_offset.tolist()
        new_batch_start = [0] + new_batch_end[:-1]
        per_batch_idx = []
        for b in range(offset.shape[0]):
            n_points = batch_end[b] - batch_start[b]
            if n_points < num_samples_total and nsample > 1:
                # Too few points for the requested dilation: spread the nsample
                # picks evenly over columns 0 .. n_points - 1. (nsample == 1
                # only ever needs column 0, and would divide by zero here.)
                soft_dilation = (n_points - 1) / (nsample - 1) - 1
            else:
                soft_dilation = dilation
            columns = [int((soft_dilation + 1) * k) for k in range(nsample)]
            per_batch_idx.append(
                idx_no_dilation[new_batch_start[b] : new_batch_end[b], columns]
            )
        idx = torch.cat(per_batch_idx, dim=0)

    if not with_feat:
        return idx

    m, c = new_xyz.shape[0], feat.shape[1]
    # reshape (not view) so a non-contiguous caller-supplied idx also works.
    flat_idx = idx.reshape(-1).long()
    grouped_xyz = xyz[flat_idx, :].view(m, nsample, 3) - new_xyz.unsqueeze(1)
    grouped_feat = feat[flat_idx, :].view(m, nsample, c)

    if with_xyz:
        return torch.cat((grouped_xyz, grouped_feat), -1), idx
    return grouped_feat, idx
|
|
|
|
def offset2batch(offset):
    """Expand cumulative batch offsets into a per-point batch index.

    Args:
        offset: (b,) integer tensor; entry ``i`` is the end (exclusive) of
            batch ``i``, so batch ``i`` holds ``offset[i] - offset[i - 1]``
            points (``offset[0]`` for batch 0).

    Returns:
        int64 tensor on ``offset.device`` giving the batch id of every point,
        e.g. ``[2, 5, 6] -> [0, 0, 1, 1, 1, 2]``. An empty ``offset`` yields an
        empty tensor.
    """
    # Points per batch, with an implicit 0 before the first batch.
    counts = torch.diff(offset, prepend=offset.new_zeros(1))
    # Vectorised: avoids building a Python list with one entry per point and
    # iterating the (possibly GPU) offset tensor element by element.
    batch_ids = torch.arange(counts.numel(), device=offset.device)
    return batch_ids.repeat_interleave(counts.long()).long()
|
|
|
|
def batch2offset(batch):
    """Convert a per-point batch index into cumulative batch offsets.

    ``[0, 0, 1, 1, 1, 2] -> [2, 5, 6]``. Batch ids must be non-negative
    integers (a ``bincount`` requirement). Roughly inverts ``offset2batch``,
    but trailing empty batches are not recovered because ``bincount`` stops
    at the largest id present. The result is an int32 tensor.
    """
    points_per_batch = batch.bincount()
    return points_per_batch.cumsum(dim=0).int()
|
|