Spaces: Running on Zero
import math
import random
import time

import roma
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from tensordict import tensorclass

from dust3r.post_process import estimate_focal_knowing_depth
from must3r.model import ActivationType, apply_activation
def save_checkpoint(model: torch.nn.Module, path: str, max_retries: int = None, retry_delay: float = 1.0) -> None:
    """Serialize ``model.state_dict()`` to ``path``, retrying on failure.

    Args:
        model: module whose parameters are saved.
        path: destination passed to ``torch.save``.
        max_retries: maximum number of retries after a failed attempt;
            ``None`` (the default) retries indefinitely, preserving the
            original behavior.
        retry_delay: seconds to sleep between attempts so that a persistent
            failure (bad path, full disk) does not busy-spin the CPU, which
            the original ``while True``/``continue`` loop did.

    Raises:
        The last ``torch.save`` exception, once ``max_retries`` is exhausted.
    """
    attempts = 0
    while True:
        try:
            torch.save(model.state_dict(), path)
            return
        except Exception as e:
            # Best-effort persistence: log and retry (transient NFS/disk errors).
            print(e)
            attempts += 1
            if max_retries is not None and attempts > max_retries:
                raise
            time.sleep(retry_delay)
def load_checkpoint(model: torch.nn.Module, ckpt_state_dict_raw: dict, strict = False) -> torch.nn.Module:
    """Load a checkpoint state dict into ``model``; always returns ``model``.

    Args:
        model: target module, mutated in place.
        ckpt_state_dict_raw: raw ``state_dict`` from the checkpoint.
        strict: when True, delegate to ``load_state_dict(strict=True)``
            semantics; when False, only load keys that exist in the model
            with matching shapes, and print the keys that were skipped.

    Returns:
        The same ``model`` object (load errors are printed, not raised).

    Note: the original returned from a ``finally`` block, which silently
    swallowed *every* exception, including KeyboardInterrupt/SystemExit;
    the plain return below only survives ``Exception``.
    """
    try:
        if strict:
            model.load_state_dict(ckpt_state_dict_raw)
        else:
            model_dict = model.state_dict()
            # Keep only checkpoint entries the model can actually accept.
            ckpt_state_dict = {k: v for k, v in ckpt_state_dict_raw.items() if k in model_dict and v.shape == model_dict[k].shape}
            model_dict.update(ckpt_state_dict)
            model.load_state_dict(model_dict)
            print(f'The following keys are in ckpt but not loaded: {set(ckpt_state_dict_raw.keys()) - set(ckpt_state_dict.keys())}')
    except Exception as e:
        # Best-effort loading: report and return the model unchanged.
        print(e)
    return model
def random_color_jitter(vid, brightness, contrast, saturation, hue = None):
    '''
    Apply random brightness / saturation / contrast jitter to a video clip.

    Args:
        vid: tensor of shape [num_frames, num_channels, height, width].
        brightness, contrast, saturation: non-negative strengths. When a
            strength is > 0 a factor is drawn uniformly from [1, 1 + strength];
            otherwise that transform is skipped entirely. (The original tested
            the strength args instead of the drawn factors, so a strength of 0
            still applied the transform with a None factor and crashed.)
        hue: optional strength; the hue transform itself remains disabled, as
            in the original, but hue=None (the default) no longer raises a
            TypeError from the `hue > 0` comparison.

    Returns:
        The jittered video tensor (the input is returned unchanged when every
        strength is 0).
    '''
    assert vid.ndim == 4
    brightness_factor = random.uniform(1, 1 + brightness) if brightness > 0 else None
    contrast_factor = random.uniform(1, 1 + contrast) if contrast > 0 else None
    saturation_factor = random.uniform(1, 1 + saturation) if saturation > 0 else None
    # Guard None before comparing: the default hue=None made `hue > 0` raise.
    hue_factor = random.uniform(0, hue) if hue is not None and hue > 0 else None
    vid_transforms = []
    if brightness_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_brightness(img, brightness_factor))
    if saturation_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_saturation(img, saturation_factor))
    # Hue jitter intentionally left disabled, matching the original code.
    # if hue_factor is not None:
    #     vid_transforms.append(lambda img: TF.adjust_hue(img, hue_factor))
    if contrast_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_contrast(img, contrast_factor))
    # Shuffle so the transforms are applied in a random order each call.
    random.shuffle(vid_transforms)
    for transform in vid_transforms:
        vid = transform(vid)
    return vid
class BatchedVideoDatapoint:
    """
    This class represents a batch of videos with associated annotations and metadata.

    NOTE(review): `pin_memory` calls `self.apply(...)`, which this class does
    not define — presumably it comes from tensordict's @tensorclass decorator
    (imported at the top of the file) being applied at the definition site;
    the decorator is not visible in this chunk, so confirm it is present.

    Attributes:
        img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
        masks: A [TxOxHxW] tensor containing binary masks for each object in the batch, where O is the number of objects.
        flat_obj_to_img_idx: An int tensor mapping each object to its image-batch index — exact shape not visible from this chunk; verify against the producer.
        features_3d: Optional float tensor of per-frame 3D features; defaults to None when absent.
    """
    img_batch: torch.FloatTensor
    masks: torch.BoolTensor
    flat_obj_to_img_idx: torch.IntTensor
    # Optional field: flat_features_3d() raises AttributeError if left as None.
    features_3d: torch.FloatTensor = None
    def pin_memory(self, device=None):
        # Pins every tensor leaf; `device` is forwarded to apply()
        # (tensordict semantics — confirm how apply consumes it).
        return self.apply(torch.Tensor.pin_memory, device=device)
    def num_frames(self) -> int:
        """
        Returns the number of frames per video (T, the leading dim of img_batch).
        """
        return self.img_batch.shape[0]
    def num_videos(self) -> int:
        """
        Returns the number of videos in the batch (B, the second dim of img_batch).
        """
        return self.img_batch.shape[1]
    def flat_img_batch(self) -> torch.FloatTensor:
        """
        Returns a flattened img_batch tensor of shape [(B*T)xCxHxW]
        """
        return self.img_batch.transpose(0, 1).flatten(0, 1)
    def flat_features_3d(self) -> torch.FloatTensor:
        """
        Returns features_3d with the video and frame dims swapped and merged,
        analogous to flat_img_batch: [(B*T)x...].
        """
        return self.features_3d.transpose(0, 1).flatten(0, 1)
def sigmoid_focal_loss(
    inputs,
    targets,
    alpha: float = 0.5,
    gamma: float = 2,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
            The prediction logits for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
            positive vs negative examples. Default = 0.5; pass a negative
            value to disable alpha weighting. (The original docstring claimed
            a default of -1, which contradicted the signature.)
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples.
    Returns:
        Un-reduced (element-wise) focal loss tensor, same shape as inputs.
    """
    prob = inputs.sigmoid()
    # Element-wise BCE on logits; reduction is left to the caller.
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction = "none")
    # p_t: probability the model assigned to the true class of each element.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    # Down-weight easy examples (p_t near 1) by (1 - p_t)^gamma.
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss
def positional_encoding(positions, freqs, dim = 1):
    """
    Sinusoidal positional encoding along one dimension.

    Each entry along `dim` is multiplied by `freqs` power-of-two frequency
    bands (2^0 .. 2^(freqs-1)); the sine and cosine of the scaled values are
    concatenated, followed by the raw positions, so the size of `dim` grows
    from S to S * freqs * 2 + S.

    Args:
        positions (torch.Tensor): Input tensor of positions.
        freqs (int): Number of frequency bands.
        dim (int): Dimension to encode along. Default is 1.
    Returns:
        torch.Tensor: [sin(p * 2^k) | cos(p * 2^k) | p] concatenated on `dim`.
    """
    assert dim >= 0 and dim < positions.ndim, "Invalid dimension specified."
    # Frequency bands 2^0 .. 2^(freqs-1), matching input dtype/device.
    bands = 2 ** torch.arange(freqs, dtype=positions.dtype, device=positions.device)
    # Insert a frequency axis right after `dim`; shape bands so it broadcasts
    # against the trailing axes only.
    trailing = positions.ndim - dim - 1
    scaled = positions.unsqueeze(dim + 1) * bands.reshape(-1, *((1,) * trailing))
    # Fold the frequency axis back into `dim`.
    lead, tail = positions.shape[:dim], positions.shape[dim + 1:]
    scaled = scaled.reshape(*lead, -1, *tail)
    return torch.cat([scaled.sin(), scaled.cos(), positions], dim = dim)
def postprocess_must3r_output(pointmaps, pointmaps_activation = ActivationType.NORM_EXP, compute_cam = True):
    """Decode a raw MUSt3R pointmap tensor into activated 3D points and,
    optionally, camera focal / pose / per-pixel rays.

    Args:
        pointmaps: tensor whose last dimension packs channels; the code slices
            [..., :3] (global points), [..., 3:6] (local/camera-frame points)
            and [..., -1] (confidence logit). Spatial layout is assumed to be
            (*batch, H, W, C) — TODO confirm against callers.
        pointmaps_activation: activation applied to both point slices via
            must3r's apply_activation.
        compute_cam: when True, also estimate focal, cam-to-world pose and rays.

    Returns:
        dict with 'pts3d'; 'pts3d_local' when channels >= 6; 'conf' when
        channels is 4 or 7; and, when compute_cam, 'focal', 'c2w',
        'ray_origin', 'ray_dir', 'ray_plucker'.

    NOTE(review): the compute_cam branch reads both out['pts3d_local'] and
    out['conf'], so it only succeeds when channels == 7; any other channel
    count raises KeyError there — confirm this is the intended contract.
    """
    out = {}
    channels = pointmaps.shape[-1]
    out['pts3d'] = pointmaps[..., :3]
    out['pts3d'] = apply_activation(out['pts3d'], activation = pointmaps_activation)
    if channels >= 6:
        out['pts3d_local'] = pointmaps[..., 3:6]
        out['pts3d_local'] = apply_activation(out['pts3d_local'], activation = pointmaps_activation)
    if channels == 4 or channels == 7:
        # conf = 1 + exp(logit) >= 1, so (conf - 1) below is a valid
        # non-negative registration weight.
        out['conf'] = 1.0 + pointmaps[..., -1].exp()
    if compute_cam:
        batch_dims = out['pts3d'].shape[:-3]
        num_batch_dims = len(batch_dims)
        H, W = out['conf'].shape[-2:]
        # Principal point fixed at the image center.
        pp = torch.tensor((W / 2, H / 2), device = out['pts3d'].device)
        focal = estimate_focal_knowing_depth(out['pts3d_local'].reshape(math.prod(batch_dims), H, W, 3), pp,
                                             focal_mode='weiszfeld')
        out['focal'] = focal.reshape(*batch_dims)
        # Rigid alignment of local (camera-frame) points onto global points
        # yields the camera-to-world rotation R and translation T.
        R, T = roma.rigid_points_registration(
            out['pts3d_local'].reshape(*batch_dims, -1, 3),
            out['pts3d'].reshape(*batch_dims, -1, 3),
            weights = out['conf'].reshape(*batch_dims, -1) - 1.0, compute_scaling = False)
        c2w = torch.eye(4, device=out['pts3d'].device)
        c2w = c2w.view(*([1] * num_batch_dims), 4, 4).repeat(*batch_dims, 1, 1)
        c2w[..., :3, :3] = R
        c2w[..., :3, 3] = T.view(*batch_dims, 3)
        out['c2w'] = c2w
        # pixel grid
        ys, xs = torch.meshgrid(
            torch.arange(H, device = out['pts3d'].device),
            torch.arange(W, device = out['pts3d'].device),
            indexing = 'ij'
        )
        # broadcast to batch
        f = out['focal'].reshape(*batch_dims, 1, 1)  # assume fx = fy = focal
        x = (xs - pp[0]) / f
        y = (ys - pp[1]) / f
        # unit-norm directions in camera frame
        d_cam = torch.stack([x, y, torch.ones_like(x)], dim=-1)
        d_cam = F.normalize(d_cam, dim=-1)
        # rotate to world frame
        d_world = torch.einsum('...ij,...hwj->...hwi', R, d_cam)
        # camera center in world frame, broadcast to every pixel
        o_world = c2w[..., :3, 3].view(*batch_dims, 1, 1, 3).expand(*batch_dims, H, W, 3)
        # Plücker coordinates: (m, d) with m = o × d
        m_world = torch.cross(o_world, d_world, dim = -1)
        plucker = torch.cat([m_world, d_world], dim = -1)  # shape: (*batch, H, W, 6)
        out['ray_origin'] = o_world
        out['ray_dir'] = d_world
        out['ray_plucker'] = plucker
    return out
def to_device(x, device = 'cuda'):
    """Recursively move every tensor inside a nested container to `device`.

    Dicts, lists and tuples are rebuilt (preserving their type) with each
    element converted; int/float/str/None leaves pass through untouched.

    Raises:
        ValueError: for any other container or leaf type.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {key: to_device(value, device) for key, value in x.items()}
    if isinstance(x, list):
        return [to_device(item, device) for item in x]
    if isinstance(x, tuple):
        return tuple(to_device(item, device) for item in x)
    if x is None or isinstance(x, (int, float, str)):
        return x
    raise ValueError(f'Unsupported type {type(x)}')