Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| from collections import abc | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mpl_toolkits.mplot3d import Axes3D | |
def fps(data: torch.Tensor, number: int) -> torch.Tensor:
    """Farthest point sampling.

    Greedily selects ``number`` points per cloud: each new point is the one
    with the largest squared Euclidean distance to its nearest already
    selected point. The first point of each cloud is drawn uniformly at
    random with ``torch.randint``.

    Args:
        data: Tensor of shape ``(B, N, C)``; distances use all ``C`` channels.
        number: Number of points to keep per cloud. Values larger than ``N``
            are not rejected; the result then necessarily repeats points.

    Returns:
        Tensor of shape ``(B, number, C)`` gathered from ``data``.
    """
    B, N, C = data.shape
    device = data.device
    centroids = torch.empty(B, number, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest selected point so far.
    distances = torch.full((B, N), float("inf"), device=device)
    farthest = torch.randint(0, N, (B,), device=device)  # random first pick
    batch_idx = torch.arange(B, device=device)
    for i in range(number):
        centroids[:, i] = farthest
        centroid = data[batch_idx, farthest]  # (B, C)
        dist = torch.sum((data - centroid[:, None, :]) ** 2, dim=-1)
        distances = torch.minimum(distances, dist)
        farthest = torch.max(distances, dim=1).indices
    # Expand the index to the real channel count rather than a hard-coded 3.
    return data.gather(1, centroids[..., None].expand(-1, -1, C))
def worker_init_fn(worker_id):
    """Seed NumPy's global RNG inside a ``DataLoader`` worker process.

    PyTorch seeds torch and Python's ``random`` in every worker, but not
    NumPy. The seed here comes from ``torch.initial_seed()``: inside a worker
    that is ``base_seed + worker_id``, and ``base_seed`` is redrawn for each
    new DataLoader iterator. The alternative of deriving the seed from
    ``np.random.get_state()`` repeats the same seeds on every iterator when
    workers are forked from a parent whose NumPy state never advances, and it
    can leave NumPy's valid seed range. The value is reduced modulo ``2**32``
    because ``np.random.seed`` only accepts ``[0, 2**32 - 1]``.

    Args:
        worker_id: Worker index passed by ``DataLoader``. Kept for the
            ``worker_init_fn`` signature; it is already folded into
            ``torch.initial_seed()`` inside a worker.
    """
    np.random.seed(torch.initial_seed() % 2**32)
def build_lambda_sche(opti, config):
    """Create a ``LambdaLR`` scheduler with exponential decay and a lower bound.

    The learning-rate multiplier at epoch ``e`` is
    ``max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)``.

    Args:
        opti: Optimizer whose parameter groups are scheduled.
        config: Object supporting ``.get("decay_step")`` and attribute access
            to ``decay_step``, ``lr_decay`` and ``lowest_decay``.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: The configured scheduler.

    Raises:
        NotImplementedError: If ``config.get("decay_step")`` is None.
    """
    if config.get("decay_step") is None:
        raise NotImplementedError()

    def lr_lbmd(e):
        decayed = config.lr_decay ** (e / config.decay_step)
        return max(decayed, config.lowest_decay)

    return torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
def build_lambda_bnsche(model, config):
    """Create a ``BNMomentumScheduler`` with exponentially decaying momentum.

    The BatchNorm momentum at epoch ``e`` is
    ``max(config.bn_momentum * config.bn_decay ** (e / config.decay_step),
    config.lowest_decay)``.

    Args:
        model: Model whose BatchNorm layers are scheduled.
        config: Object supporting ``.get("decay_step")`` and attribute access
            to ``decay_step``, ``bn_momentum``, ``bn_decay`` and ``lowest_decay``.

    Returns:
        BNMomentumScheduler: The configured scheduler.

    Raises:
        NotImplementedError: If ``config.get("decay_step")`` is None.
    """
    if config.get("decay_step") is None:
        raise NotImplementedError()

    def bnm_lmbd(e):
        momentum = config.bn_momentum * config.bn_decay ** (e / config.decay_step)
        return max(momentum, config.lowest_decay)

    return BNMomentumScheduler(model, bnm_lmbd)
def set_random_seed(seed, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): Seed passed to every generator.
        deterministic (bool): If True, additionally set
            ``torch.backends.cudnn.deterministic = True`` and
            ``torch.backends.cudnn.benchmark = False``, trading speed for
            reproducible cuDNN kernels. If False, both cuDNN flags are left
            untouched. Default: False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def is_seq_of(seq, expected_type, seq_type=None):
    """Return whether ``seq`` is a sequence whose items all have ``expected_type``.

    Args:
        seq: Object to check.
        expected_type (type or tuple of types): Type every item must be an
            instance of.
        seq_type (type, optional): Required container type. Defaults to
            ``collections.abc.Sequence``.

    Returns:
        bool: False if ``seq`` is not an instance of the container type or if
        any item fails the ``isinstance`` check; True otherwise (including for
        an empty sequence).
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
    container_type = abc.Sequence if seq_type is None else seq_type
    return isinstance(seq, container_type) and all(
        isinstance(item, expected_type) for item in seq
    )
def set_bn_momentum_default(bn_momentum):
    """Build an ``nn.Module.apply`` callback that sets BatchNorm momentum.

    Args:
        bn_momentum: Value assigned to ``momentum`` on every ``BatchNorm1d``,
            ``BatchNorm2d`` or ``BatchNorm3d`` module the callback visits.

    Returns:
        A one-argument callable; modules of other types are left unchanged.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    def fn(m):
        if isinstance(m, bn_types):
            m.momentum = bn_momentum

    return fn
class BNMomentumScheduler:
    """Epoch-indexed scheduler for BatchNorm momentum.

    On every ``step`` the value ``bn_lambda(epoch)`` is pushed into the model
    through ``model.apply(setter(value))``.

    Args:
        model (nn.Module): Model whose BatchNorm layers are updated.
        bn_lambda: Callable mapping an epoch number to a momentum value.
        last_epoch (int): Index of the last finished epoch. Default: -1.
        setter: Factory turning a momentum value into an ``nn.Module.apply``
            callback. Default: ``set_bn_momentum_default``.

    Raises:
        RuntimeError: If ``model`` is not an ``nn.Module``.
    """

    def __init__(self, model, bn_lambda, last_epoch=-1, setter=set_bn_momentum_default):
        if not isinstance(model, nn.Module):
            raise RuntimeError(
                f"Class '{type(model).__name__}' is not a PyTorch nn Module"
            )
        self.model = model
        self.setter = setter
        self.lmbd = bn_lambda
        # Apply the momentum for epoch ``last_epoch + 1`` immediately, then
        # reset the counter (``step`` advanced it) so the next argument-less
        # ``step()`` targets that same epoch again.
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def step(self, epoch=None):
        """Apply the momentum for ``epoch`` (default: the one after ``last_epoch``)."""
        target = self.last_epoch + 1 if epoch is None else epoch
        self.last_epoch = target
        momentum = self.lmbd(target)
        self.model.apply(self.setter(momentum))

    def get_momentum(self, epoch=None):
        """Return ``bn_lambda(epoch)`` without modifying the model."""
        target = self.last_epoch + 1 if epoch is None else epoch
        return self.lmbd(target)
def seprate_point_cloud(xyz, num_points, crop, fixed_points=None, padding_zeros=False):
    """Split each point cloud into an incomplete part and a cropped-away part.

    For every cloud in the batch a crop center is chosen: a random unit vector
    when ``fixed_points`` is None, otherwise the given point (or one sampled
    from the given list). The ``num_crop`` points nearest to that center form
    the crop; the remaining points form the incomplete input.

    Args:
        xyz: Tensor of shape ``(B, num_points, 3)`` (both sizes are asserted).
        num_points: Expected number of points per cloud.
        crop: Number of points to crop, or a ``[low, high]`` list from which
            ``num_crop`` is drawn per cloud with ``random.randint`` (inclusive).
        fixed_points: None, a tensor with 3 elements, or a list of such tensors.
        padding_zeros: If True, cropped points are zeroed in the input instead
            of removed, so the input keeps all ``num_points`` points.

    Returns:
        ``(input_data, crop_data)`` as contiguous tensors concatenated over the
        batch. When ``crop`` is a list, both parts are resampled with ``fps``
        to 2048 points each so clouds with different ``num_crop`` can be
        stacked. When ``crop == num_points``, returns ``(xyz, None)``.
    """
    _, n, c = xyz.shape
    assert n == num_points
    assert c == 3
    if crop == num_points:
        return xyz, None

    device = xyz.device
    inputs = []
    crops = []
    for points in xyz:
        if isinstance(crop, list):
            num_crop = random.randint(crop[0], crop[1])
        else:
            num_crop = crop
        points = points.unsqueeze(0)  # 1 N 3

        if fixed_points is None:
            # Random unit direction, drawn with the CPU generator and then moved
            # to xyz's device (not forced onto CUDA) so CPU inputs and inputs
            # on any GPU work.
            center = F.normalize(torch.randn(1, 1, 3), p=2, dim=-1).to(device)
        else:
            if isinstance(fixed_points, list):
                fixed_point = random.sample(fixed_points, 1)[0]
            else:
                fixed_point = fixed_points
            center = fixed_point.reshape(1, 1, 3).to(device)

        distance_matrix = torch.norm(
            center.unsqueeze(2) - points.unsqueeze(1), p=2, dim=-1
        )  # 1 1 N
        # Point indices ordered from nearest to farthest from the center.
        idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0, 0]  # N

        if padding_zeros:
            # Keep all N points but zero out the cropped ones.
            input_data = points.clone()
            input_data[0, idx[:num_crop]] = input_data[0, idx[:num_crop]] * 0
        else:
            input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0)  # 1 N-num_crop 3
        crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)  # 1 num_crop 3

        if isinstance(crop, list):
            # num_crop differs per cloud; resample to a fixed size so the
            # per-cloud results can be concatenated along the batch axis.
            inputs.append(fps(input_data, 2048))
            crops.append(fps(crop_data, 2048))
        else:
            inputs.append(input_data)
            crops.append(crop_data)

    input_data = torch.cat(inputs, dim=0)  # B N' 3
    crop_data = torch.cat(crops, dim=0)  # B M 3
    return input_data.contiguous(), crop_data.contiguous()
def get_ptcloud_img(ptcloud, roll, pitch):
    """Render a point cloud to an RGB image with Matplotlib.

    Args:
        ptcloud: 2-D array of shape ``(N, 3)``. Its columns are unpacked as
            x, z, y, i.e. the second and third columns are swapped for display.
        roll: Passed as the first argument of ``view_init`` (elevation).
        pitch: Passed as the second argument of ``view_init`` (azimuth).

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` with the rendered figure.
    """
    fig = plt.figure(figsize=(8, 8))
    try:
        x, z, y = ptcloud.transpose(1, 0)
        # Figure.gca() no longer accepts keyword arguments (Matplotlib >= 3.6);
        # add_subplot builds the same single 3D axes.
        ax = fig.add_subplot(projection=Axes3D.name, adjustable="box")
        ax.axis("off")
        ax.view_init(roll, pitch)
        lo, hi = np.min(ptcloud), np.max(ptcloud)  # avoid shadowing min/max
        ax.set_xbound(lo, hi)
        ax.set_ybound(lo, hi)
        ax.set_zbound(lo, hi)
        ax.scatter(x, y, z, zdir="z", c=y, cmap="jet")
        fig.canvas.draw()
        # tostring_rgb()/np.fromstring are deprecated or removed. The RGBA
        # buffer already has shape (H, W, 4); drop alpha and copy it out.
        img = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    finally:
        # Otherwise every call leaves an open figure in pyplot's registry.
        plt.close(fig)
    return img
def visualize_KITTI(
    path,
    data_list,
    titles=("input", "pred"),
    cmap=("bwr", "autumn"),
    zdir="y",
    xlim=(-1, 1),
    ylim=(-1, 1),
    zlim=(-1, 1),
):
    """Plot point clouds side by side and save the figure plus raw arrays.

    Args:
        path: Output location without extension. The figure is written to
            ``path + ".png"``; ``path`` itself is created as a directory and
            receives ``input.npy`` and ``pred.npy``.
        data_list: Sequence of point clouds indexable as ``cloud[:, k]``.
            ``data_list[0]`` and ``data_list[1]`` must provide ``.numpy()``
            (e.g. CPU torch tensors) because they are saved to disk.
        titles: Subplot titles, one per entry of ``data_list``.
        cmap: Colormap names; only ``cmap[0]`` is used, for every subplot.
        zdir: Forwarded to ``Axes3D.scatter``.
        xlim, ylim, zlim: Axis limits applied to every subplot.
    """
    fig = plt.figure(figsize=(6 * len(data_list), 6))
    try:
        # Colors are the x coordinate scaled by the max x of the last cloud.
        cmax = data_list[-1][:, 0].max()
        for i, cloud in enumerate(data_list):
            # NOTE(review): the second cloud has its last 2048 points dropped
            # before plotting; presumably those are appended input points — confirm.
            data = cloud[:-2048] if i == 1 else cloud
            color = data[:, 0] / cmax
            ax = fig.add_subplot(1, len(data_list), i + 1, projection="3d")
            ax.view_init(30, -120)
            ax.scatter(
                data[:, 0],
                data[:, 1],
                data[:, 2],
                zdir=zdir,
                c=color,
                vmin=-1,
                vmax=1,
                # NOTE(review): cmap[1] is never used; cmap[i] may have been intended.
                cmap=cmap[0],
                s=4,
                linewidth=0.05,
                edgecolors="black",
            )
            ax.set_title(titles[i])
            ax.set_axis_off()
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            ax.set_zlim(zlim)
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(path, exist_ok=True)
        fig.savefig(path + ".png")
        np.save(os.path.join(path, "input.npy"), data_list[0].numpy())
        np.save(os.path.join(path, "pred.npy"), data_list[1].numpy())
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def random_dropping(pc, e):
    """Keep a random number of FPS-sampled points and zero-pad back to 2048.

    The number of kept points is drawn with ``torch.randint`` from
    ``[1, up_num)`` where ``up_num = max(64, 768 // (e // 50 + 1))``, so the
    upper bound shrinks as ``e`` grows.

    Args:
        pc: Batched point cloud ``(B, N, C)`` accepted by ``fps``.
        e: Epoch index used to compute ``up_num``.

    Returns:
        Tensor of shape ``(B, 2048, C)``: the sampled points followed by zero
        rows, with the same dtype and device as ``pc``.
    """
    up_num = max(64, 768 // (e // 50 + 1))
    random_num = torch.randint(1, up_num, (1, 1))[0, 0]
    pc = fps(pc, random_num)
    # new_zeros matches pc's dtype and device; a float32 torch.zeros(...) would
    # promote e.g. half-precision inputs to float32 in the cat below.
    padding = pc.new_zeros(pc.size(0), 2048 - pc.size(1), pc.size(-1))
    return torch.cat([pc, padding], dim=1)
def random_scale(partial, scale_range=(0.8, 1.2)):
    """Multiply a tensor by one random factor drawn uniformly from ``scale_range``.

    Args:
        partial: Tensor to scale.
        scale_range: ``(low, high)`` bounds of the factor. Default: ``(0.8, 1.2)``.

    Returns:
        ``partial * scale``, where the single scalar ``scale`` lies in
        ``[low, high)`` and is shared by all elements.
    """
    low, high = scale_range[0], scale_range[1]
    # Draw with the CPU generator, then place the factor on the input's device
    # instead of forcing CUDA, so CPU tensors are supported.
    scale = torch.rand(1).to(partial.device) * (high - low) + low
    return partial * scale