Spaces:
Paused
Paused
| import numpy as np | |
| import torch | |
| from torch import optim | |
| from torch import nn | |
| from tqdm import tqdm | |
| from matplotlib import pyplot as plt | |
| import torch.nn.functional as F | |
| from collections import defaultdict | |
| import os | |
| import roma | |
class unicycle(torch.nn.Module):
    """Unicycle motion model for the planar trajectory of one tracked object.

    The per-sample planar position (``a``, ``b``), speed ``v`` and heading
    ``yaw`` are registered as learnable parameters. ``pitchroll`` and the
    height ``h`` are stored as plain tensors and are not optimised here.
    """

    def __init__(self, train_timestamp, centers, eulers, heights=None):
        """Build the model from per-sample observations.

        Args:
            train_timestamp: 1-D tensor of observation times. ``forward``
                searches it with ``torch.searchsorted``, so it is expected to
                be sorted in ascending order.
            centers: tensor read as ``centers[:, 0]`` and ``centers[:, 1]``,
                the two planar coordinates of every sample.
            eulers: tensor whose first two columns are kept as ``pitchroll``
                and whose last column initialises ``yaw``.
            heights: optional per-sample height tensor; zeros shaped like
                ``train_timestamp`` are used when it is ``None``.
        """
        super().__init__()
        self.train_timestamp = train_timestamp
        # Time step between consecutive samples (one element fewer than samples).
        self.delta = torch.diff(self.train_timestamp)

        # Copies of the observed positions, used as the target of pos_loss().
        self.input_a = centers[:, 0].clone()
        self.input_b = centers[:, 1].clone()
        self.a = nn.Parameter(centers[:, 0])
        self.b = nn.Parameter(centers[:, 1])

        # Initial speed from finite differences of the positions. The last
        # value is repeated so that v has one entry per sample.
        diff_a = torch.diff(centers[:, 0]) / self.delta
        diff_b = torch.diff(centers[:, 1]) / self.delta
        v = torch.sqrt(diff_a ** 2 + diff_b ** 2)
        self.v = nn.Parameter(F.pad(v, (0, 1), 'constant', v[-1].item()))

        self.pitchroll = eulers[:, :2]
        self.yaw = nn.Parameter(eulers[:, -1])

        if heights is None:
            self.h = torch.zeros_like(train_timestamp).float()
        else:
            self.h = heights

    def acc_omega(self):
        """Return finite-difference acceleration and angular rate per sample.

        Both are computed from consecutive ``v`` / ``yaw`` values divided by
        ``delta`` and then padded with their last value (as a constant) so
        they have one entry per sample.
        """
        acc = torch.diff(self.v) / self.delta
        omega = torch.diff(self.yaw) / self.delta
        acc = F.pad(acc, (0, 1), 'constant', acc[-1].item())
        omega = F.pad(omega, (0, 1), 'constant', omega[-1].item())
        return acc, omega

    def forward(self, timestamp):
        """Evaluate the model at ``timestamp``.

        Returns:
            Tuple ``(a, b, v, pitchroll, yaw, h)``.

            * Before the first / after the last training timestamp the
              position is extrapolated along a straight line using the
              boundary speed and heading; the other values are the boundary
              values.
            * At an exact training timestamp the stored values are returned.
            * Otherwise the state is propagated from the previous sample
              using that sample's finite-difference acceleration and angular
              rate; ``pitchroll`` and ``h`` are taken from the previous
              sample.
        """
        if timestamp < self.train_timestamp[0]:
            delta_t = self.train_timestamp[0] - timestamp
            a = self.a[0] - delta_t * torch.cos(self.yaw[0]) * self.v[0]
            b = self.b[0] - delta_t * torch.sin(self.yaw[0]) * self.v[0]
            return a, b, self.v[0], self.pitchroll[0], self.yaw[0], self.h[0]
        elif timestamp > self.train_timestamp[-1]:
            delta_t = timestamp - self.train_timestamp[-1]
            a = self.a[-1] + delta_t * torch.cos(self.yaw[-1]) * self.v[-1]
            b = self.b[-1] + delta_t * torch.sin(self.yaw[-1]) * self.v[-1]
            return a, b, self.v[-1], self.pitchroll[-1], self.yaw[-1], self.h[-1]

        # Index of the first training timestamp that is >= timestamp.
        idx = torch.searchsorted(self.train_timestamp, timestamp, side='left')
        if self.train_timestamp[idx] == timestamp:
            return self.a[idx], self.b[idx], self.v[idx], self.pitchroll[idx], self.yaw[idx], self.h[idx]
        else:
            prev_timestamps = self.train_timestamp[idx-1]
            delta_t = timestamp - prev_timestamps
            prev_a, prev_b = self.a[idx-1], self.b[idx-1]
            prev_v, prev_yaw = self.v[idx-1], self.yaw[idx-1]
            acc, omega = self.acc_omega()
            v = prev_v + acc[idx-1] * delta_t
            yaw = prev_yaw + omega[idx-1] * delta_t
            # Position update for motion at speed prev_v with angular rate
            # omega. The 1e-6 keeps the denominator away from an exact zero.
            # NOTE(review): for omega == 0 the numerator is also 0, so the
            # position does not advance; omega close to -1e-6 makes the
            # denominator vanish. Left unchanged to preserve fitted results.
            a = prev_a + prev_v * ((torch.sin(yaw) - torch.sin(prev_yaw)) / (omega[idx-1] + 1e-6))
            b = prev_b - prev_v * ((torch.cos(yaw) - torch.cos(prev_yaw)) / (omega[idx-1] + 1e-6))
            h = self.h[idx-1]
            return a, b, v, self.pitchroll[idx-1], yaw, h

    def capture(self):
        """Return the model state as a tuple in the order ``restore`` expects."""
        return (
            self.a,
            self.b,
            self.v,
            self.pitchroll,
            self.yaw,
            self.h,
            self.train_timestamp,
            self.delta
        )

    @classmethod
    def restore(cls, model_args):
        """Rebuild a model from the tuple produced by ``capture``.

        A fresh instance is constructed from detached copies of the captured
        tensors, after which every captured object is assigned onto it so the
        restored model carries exactly the captured state.

        Declared as a classmethod so it can be called as
        ``unicycle.restore(state)``; without the decorator that call bound the
        state tuple to ``cls`` and failed with a missing-argument error.
        """
        (
            a,
            b,
            v,
            pitchroll,
            yaw,
            h,
            train_timestamp,
            delta
        ) = model_args
        model = cls(train_timestamp,
                    torch.stack([a.clone().detach(), b.clone().detach()], dim=-1),
                    torch.concat([pitchroll.clone().detach(), yaw.clone().detach()[:, None]], dim=-1),
                    h.clone().detach())
        model.a = a
        model.b = b
        model.v = v
        model.pitchroll = pitchroll
        model.yaw = yaw
        model.h = h
        model.train_timestamp = train_timestamp
        model.delta = delta
        return model

    def visualize(self, save_path, noise_centers=None, gt_centers=None):
        """Save a scatter/quiver plot of the current positions and headings.

        Args:
            save_path: file path handed to ``plt.savefig``.
            noise_centers: optional tensor of points drawn as gray circles.
            gt_centers: optional tensor of points drawn as green triangles.
        """
        a = self.a.detach().cpu().numpy()
        b = self.b.detach().cpu().numpy()
        yaw = self.yaw.detach().cpu().numpy()
        plt.scatter(a, b, marker='x', color='b')
        # Unit-length arrows pointing along the current heading.
        plt.quiver(a, b, np.ones_like(a) * np.cos(yaw), np.ones_like(b) * np.sin(yaw), scale=20, width=0.005)
        if noise_centers is not None:
            noise_centers = noise_centers.detach().cpu().numpy()
            plt.scatter(noise_centers[:, 0], noise_centers[:, 1], marker='o', color='gray')
        if gt_centers is not None:
            gt_centers = gt_centers.detach().cpu().numpy()
            plt.scatter(gt_centers[:, 0], gt_centers[:, 1], marker='v', color='g')
        plt.axis('equal')
        plt.savefig(save_path)
        plt.close()

    def reg_loss(self):
        """Return the regularisation loss.

        Sum of three terms: mean absolute change of acceleration (weight
        0.01), mean absolute change of angular rate (weight 0.1), and the mean
        squared gap between each position and the position predicted from the
        previous sample by the motion model (weight 1).
        """
        reg = 0
        acc, omega = self.acc_omega()
        reg += torch.mean(torch.abs(torch.diff(acc))) * 0.01
        reg += torch.mean(torch.abs(torch.diff(omega))) * 0.1
        # One-step prediction of the next position from the previous sample.
        reg_a_motion = self.v[:-1] * ((torch.sin(self.yaw[1:]) - torch.sin(self.yaw[:-1])) / (omega[:-1] + 1e-6))
        reg_b_motion = -self.v[:-1] * ((torch.cos(self.yaw[1:]) - torch.cos(self.yaw[:-1])) / (omega[:-1] + 1e-6))
        reg_a = self.a[:-1] + reg_a_motion
        reg_b = self.b[:-1] + reg_b_motion
        reg += torch.mean((reg_a - self.a[1:])**2 + (reg_b - self.b[1:])**2) * 1
        return reg

    def pos_loss(self):
        """Return 10x the mean squared distance to the observed positions."""
        return torch.mean((self.a - self.input_a) ** 2 + (self.b - self.input_b) ** 2) * 10
def create_unicycle_model(train_cams, model_path, opt_iter=0, opt_pos=False, data_type='kitti'):
    """Fit one ``unicycle`` model per tracked object and save it to disk.

    Only the cameras whose ``image_name`` matches the reference camera of the
    given ``data_type`` are used. For every track id found in their
    ``dynamics`` dict a model is initialised from the object poses, optimised
    for ``opt_iter`` Adam steps, plotted and checkpointed.

    Args:
        train_cams: iterable of camera objects exposing ``image_name``,
            ``timestamp`` and ``dynamics`` (a mapping of track id to a pose
            matrix indexed as ``b2w[:3, :3]`` / ``b2w[..., 3]``).
        model_path: output directory. ``<model_path>/unicycle`` receives the
            plots and ``<model_path>/ckpts`` the ``unicycle_<track_id>.pth``
            checkpoints; both directories are created if missing.
        opt_iter: number of optimisation steps per track.
        opt_pos: when true the positions ``a`` / ``b`` are optimised in
            addition to ``v`` and ``yaw``.
        data_type: one of ``'kitti'``, ``'waymo'``, ``'nuscenes'``,
            ``'pandaset'``.

    Returns:
        Dict mapping track id to a dict with keys ``'model'``,
        ``'optimizer'`` and ``'input_centers'``.

    Raises:
        NotImplementedError: for an unknown ``data_type``.

    All tensors are moved with ``.cuda()``, so a CUDA device is required.
    """
    unicycle_models = {}
    if data_type == 'kitti':
        cameras = [cam for cam in train_cams if 'cam_0' in cam.image_name]
    elif data_type == 'waymo':
        cameras = [cam for cam in train_cams if 'cam_1' in cam.image_name]
    elif data_type == 'nuscenes':
        cameras = [cam for cam in train_cams if (('CAM_FRONT' in cam.image_name) and ('LEFT' not in cam.image_name) and ('RIGHT' not in cam.image_name))]
    elif data_type == 'pandaset':
        cameras = [cam for cam in train_cams if 'front_camera' in cam.image_name]
    else:
        raise NotImplementedError(f"Unsupported data_type: {data_type!r}")

    cameras = sorted(cameras, key=lambda x: x.timestamp)

    os.makedirs(os.path.join(model_path, "unicycle"), exist_ok=True)
    # torch.save below writes into <model_path>/ckpts and does not create
    # missing parent directories, so make sure it exists.
    os.makedirs(os.path.join(model_path, "ckpts"), exist_ok=True)

    # Timestamps are stored relative to the first selected camera.
    start_time = cameras[0].timestamp

    all_centers, all_heights, all_eulers, all_timestamps = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
    for cam in cameras:
        t = cam.timestamp - start_time
        for track_id, b2w in cam.dynamics.items():
            # Rows 0 and 2 of the translation column are the planar
            # coordinates; row 1 is kept separately as the height.
            all_centers[track_id].append(b2w[[0, 2], 3])
            all_heights[track_id].append(b2w[1, 3])
            eulers = roma.rotmat_to_euler('xzy', b2w[:3, :3])
            # NOTE(review): the sign flip and pi/2 offset presumably map the
            # angles to the heading convention of the unicycle model - confirm.
            all_eulers[track_id].append(-eulers + torch.pi / 2)
            all_timestamps[track_id].append(t)

    for track_id in all_centers.keys():
        centers = torch.stack(all_centers[track_id], dim=0).cuda()
        timestamps = torch.tensor(all_timestamps[track_id]).cuda()
        heights = torch.tensor(all_heights[track_id]).cuda()
        eulers = torch.stack(all_eulers[track_id]).cuda()

        model = unicycle(timestamps, centers.clone(), eulers.clone(), heights.clone())
        model.visualize(os.path.join(model_path, "unicycle", f"{track_id}_init.png"))

        param_groups = [
            {'params': [model.v], 'lr': 1e-3, "name": "v"},
            {'params': [model.yaw], 'lr': 1e-4, "name": "yaw"},
        ]
        if opt_pos:
            param_groups.extend([
                {'params': [model.a], 'lr': 1e-3, "name": "a"},
                {'params': [model.b], 'lr': 1e-3, "name": "b"},
            ])
        # Every group carries its own lr; the default lr is never used.
        optimizer = optim.Adam(param_groups, lr=0.0)

        t_range = tqdm(range(opt_iter), desc=f"Fitting {track_id}")
        for _ in t_range:
            loss = 5e-3 * model.reg_loss() + 1e-3 * model.pos_loss()
            t_range.set_postfix({'loss': loss.item()})
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        unicycle_models[track_id] = {'model': model,
                                     'optimizer': optimizer,
                                     'input_centers': centers}
        model.visualize(os.path.join(model_path, "unicycle", f"{track_id}_iter0.png"))
        torch.save(model.capture(), os.path.join(model_path, f"ckpts/unicycle_{track_id}.pth"))

    return unicycle_models
if __name__ == "__main__":
    # Stand-alone entry point: load the merged configuration, build the scene
    # and fit the unicycle models for its training cameras.
    from scene import Scene, GaussianModel
    from omegaconf import OmegaConf
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Training script parameters")
    cli.add_argument("--base_cfg", type=str, default="./configs/gs_base.yaml")
    cli.add_argument("--data_cfg", type=str, default="./configs/nusc.yaml")
    cli.add_argument("--source_path", type=str, default="")
    cli.add_argument("--model_path", type=str, default="")
    args = cli.parse_args()

    # The dataset config is merged on top of the base config.
    base_cfg = OmegaConf.load(args.base_cfg)
    data_cfg = OmegaConf.load(args.data_cfg)
    cfg = OmegaConf.merge(base_cfg, data_cfg)

    # Non-empty command-line paths override the values from the config files.
    if args.source_path:
        cfg.source_path = args.source_path
    if args.model_path:
        cfg.model_path = args.model_path

    gaussians = GaussianModel(3, feat_mutable=True)
    print("loading scene...")
    scene = Scene(cfg, gaussians, data_type=cfg.data_type)

    create_unicycle_model(
        scene.getTrainCameras(),
        cfg.model_path,
        opt_iter=500,
        opt_pos=False,
        data_type=cfg.data_type,
    )