from must3r.model import ActivationType, apply_activation
from dust3r.post_process import estimate_focal_knowing_depth
import torch
import random, math, roma
import torchvision.transforms.functional as TF
from tensordict import tensorclass
import torch.nn.functional as F


def save_checkpoint(model: torch.nn.Module, path: str) -> None:
    """Persist ``model``'s state dict to ``path``, retrying until the save succeeds.

    NOTE(review): this retries forever; a persistent failure (bad path,
    full disk) loops indefinitely. Consider bounded retries with backoff.
    """
    while True:
        try:
            torch.save(model.state_dict(), path)
            break
        except Exception as e:
            # Best-effort: report and try again (e.g. transient NFS hiccup).
            print(e)
            continue


def load_checkpoint(model: torch.nn.Module, ckpt_state_dict_raw: dict, strict: bool = False) -> torch.nn.Module:
    """Load a checkpoint state dict into ``model`` and return the model.

    When ``strict`` is False, only keys that exist in the model with a
    matching tensor shape are loaded; everything else is reported and
    skipped. Errors are printed rather than raised, and the model is
    always returned (possibly partially loaded).
    """
    try:
        if strict:
            model.load_state_dict(ckpt_state_dict_raw)
        else:
            model_dict = model.state_dict()
            # Keep only checkpoint entries whose name AND shape match the model.
            ckpt_state_dict = {k: v for k, v in ckpt_state_dict_raw.items()
                               if k in model_dict and v.shape == model_dict[k].shape}
            model_dict.update(ckpt_state_dict)
            model.load_state_dict(model_dict)
            print(f'The following keys is in ckpt but not loaded: {set(ckpt_state_dict_raw.keys()) - set(ckpt_state_dict.keys())}')
    except Exception as e:
        print(e)
    finally:
        return model


def random_color_jitter(vid, brightness, contrast, saturation, hue=None):
    '''
    Apply one random color-jitter pass (shared factors across frames) to a video.

    vid of shape [num_frames, num_channels, height, width]
    brightness / contrast / saturation: non-negative jitter strengths; a
    strength of 0 disables that transform.
    hue: optional jitter strength (currently unused — hue jitter is disabled).
    '''
    assert vid.ndim == 4
    # Sample one factor per enabled transform; None means "disabled".
    brightness_factor = random.uniform(1, 1 + brightness) if brightness > 0 else None
    contrast_factor = random.uniform(1, 1 + contrast) if contrast > 0 else None
    saturation_factor = random.uniform(1, 1 + saturation) if saturation > 0 else None
    # Bug fix: guard against hue=None (the default) before comparing with 0.
    hue_factor = random.uniform(0, hue) if hue is not None and hue > 0 else None

    vid_transforms = []
    # Bug fix: gate on the sampled *factor* (None when strength == 0), not on
    # the raw strength argument — positional args are never None, so the old
    # checks appended transforms with a None factor and crashed in torchvision.
    if brightness_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_brightness(img, brightness_factor))
    if saturation_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_saturation(img, saturation_factor))
    # Hue jitter was deliberately disabled in the original; kept off.
    # if hue_factor is not None:
    #     vid_transforms.append(lambda img: TF.adjust_hue(img, hue_factor))
    if contrast_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_contrast(img, contrast_factor))

    # Apply the enabled transforms in a random order.
    random.shuffle(vid_transforms)
    for transform in vid_transforms:
        vid = transform(vid)
    return vid


@tensorclass
class BatchedVideoDatapoint:
    """
    This class represents a batch of videos with associated annotations and metadata.
    Attributes:
        img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch,
            where T is the number of frames per video, and B is the number of videos in the batch.
        obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object
            belongs to. O is the number of objects in the batch.
        masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
    """

    img_batch: torch.FloatTensor
    masks: torch.BoolTensor
    flat_obj_to_img_idx: torch.IntTensor
    features_3d: torch.FloatTensor = None

    def pin_memory(self, device=None):
        # Pin every tensor field so host->device copies can be async.
        return self.apply(torch.Tensor.pin_memory, device=device)

    @property
    def num_frames(self) -> int:
        """
        Returns the number of frames per video.
        """
        return self.img_batch.shape[0]

    @property
    def num_videos(self) -> int:
        """
        Returns the number of videos in the batch.
        """
        return self.img_batch.shape[1]

    @property
    def flat_img_batch(self) -> torch.FloatTensor:
        """
        Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW]
        """
        return self.img_batch.transpose(0, 1).flatten(0, 1)

    @property
    def flat_features_3d(self) -> torch.FloatTensor:
        """
        Returns a flattened features_3d tensor of shape [(B*T)xCxHxW]
        """
        return self.features_3d.transpose(0, 1).flatten(0, 1)


def sigmoid_focal_loss(
    inputs,
    targets,
    alpha: float = 0.5,
    gamma: float = 2,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = 0.5; a negative
               value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        focal loss tensor (same shape as inputs, unreduced)
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the model's probability for the true class of each element.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss


def positional_encoding(positions, freqs, dim=1):
    """
    Applies positional encoding along a specified dimension, expanding the
    dimension size based on the number of frequency bands.

    Args:
        positions (torch.Tensor): Input tensor representing positions
            (e.g., shape (1, 3, 256, 256)).
        freqs (int): Number of frequency bands for encoding.
        dim (int): Dimension along which to apply encoding. Default is 1.

    Returns:
        torch.Tensor: Tensor with positional encoding applied along the
        specified dimension; its size along ``dim`` becomes
        ``orig * (2 * freqs + 1)`` (sin, cos, and the raw positions).
    """
    # Ensure that the specified dimension is valid
    assert dim >= 0 and dim < positions.ndim, "Invalid dimension specified."

    # Generate frequency bands: 2^0, 2^1, ..., 2^(freqs-1)
    freq_bands = (2 ** torch.arange(freqs, dtype=positions.dtype, device=positions.device))

    # Apply frequency bands to positions at the specified dimension.
    # unsqueeze(dim + 1) inserts a singleton the freq axis broadcasts into;
    # the trailing singleton view aligns freq_bands with that new axis.
    expanded_positions = positions.unsqueeze(dim + 1) * freq_bands.view(
        -1, *([1] * (positions.ndim - dim - 1)))

    # Reshape to combine the new frequency dimension with the specified dim
    encoded_positions = expanded_positions.reshape(
        *positions.shape[:dim], -1, *positions.shape[dim + 1:]
    )

    # Concatenate sine and cosine encodings (plus the raw positions)
    positional_encoded = torch.cat(
        [torch.sin(encoded_positions), torch.cos(encoded_positions), positions], dim=dim)
    return positional_encoded


@torch.autocast("cuda", dtype=torch.float32)
def postprocess_must3r_output(pointmaps, pointmaps_activation=ActivationType.NORM_EXP, compute_cam=True):
    """
    Decode raw MUSt3R head output into world/local pointmaps, confidence,
    and (optionally) per-view camera parameters and ray maps.

    Args:
        pointmaps: tensor of shape (*batch, H, W, C) with C in {3, 4, 6, 7}:
            channels 0:3 = world points, 3:6 = camera-local points (if present),
            last channel = raw confidence (if C is 4 or 7).
        pointmaps_activation: activation applied to both pointmaps.
        compute_cam: if True, also estimate focal length, camera-to-world
            pose, and Plücker ray maps. Requires the local pointmap
            (C >= 6) to be present.

    Returns:
        dict with 'pts3d', optionally 'pts3d_local', 'conf', 'focal',
        'c2w', 'ray_origin', 'ray_dir', 'ray_plucker'.
    """
    out = {}
    channels = pointmaps.shape[-1]
    out['pts3d'] = pointmaps[..., :3]
    out['pts3d'] = apply_activation(out['pts3d'], activation=pointmaps_activation)
    if channels >= 6:
        out['pts3d_local'] = pointmaps[..., 3:6]
        out['pts3d_local'] = apply_activation(out['pts3d_local'], activation=pointmaps_activation)
    if channels == 4 or channels == 7:
        # Confidence is 1 + exp(raw) >= 1, so (conf - 1) is a valid weight.
        out['conf'] = 1.0 + pointmaps[..., -1].exp()
    if compute_cam:
        batch_dims = out['pts3d'].shape[:-3]
        num_batch_dims = len(batch_dims)
        # Bug fix: read H, W from pts3d instead of out['conf'], which does not
        # exist when channels == 6 (no confidence channel).
        H, W = out['pts3d'].shape[-3:-1]
        pp = torch.tensor((W / 2, H / 2), device=out['pts3d'].device)
        focal = estimate_focal_knowing_depth(
            out['pts3d_local'].reshape(math.prod(batch_dims), H, W, 3), pp, focal_mode='weiszfeld')
        out['focal'] = focal.reshape(*batch_dims)
        # Register local points onto world points to recover the c2w pose.
        # Bug fix: fall back to unweighted registration when no confidence.
        weights = out['conf'].reshape(*batch_dims, -1) - 1.0 if 'conf' in out else None
        R, T = roma.rigid_points_registration(
            out['pts3d_local'].reshape(*batch_dims, -1, 3),
            out['pts3d'].reshape(*batch_dims, -1, 3),
            weights=weights,
            compute_scaling=False)
        c2w = torch.eye(4, device=out['pts3d'].device)
        c2w = c2w.view(*([1] * num_batch_dims), 4, 4).repeat(*batch_dims, 1, 1)
        c2w[..., :3, :3] = R
        c2w[..., :3, 3] = T.view(*batch_dims, 3)
        out['c2w'] = c2w

        # pixel grid
        ys, xs = torch.meshgrid(
            torch.arange(H, device=out['pts3d'].device),
            torch.arange(W, device=out['pts3d'].device),
            indexing='ij'
        )

        # broadcast to batch
        f = out['focal'].reshape(*batch_dims, 1, 1)  # assume fx = fy = focal
        x = (xs - pp[0]) / f
        y = (ys - pp[1]) / f

        # directions in camera frame
        d_cam = torch.stack([x, y, torch.ones_like(x)], dim=-1)
        d_cam = F.normalize(d_cam, dim=-1)

        # rotate to world frame
        d_world = torch.einsum('...ij,...hwj->...hwi', R, d_cam)

        # camera center in world frame
        o_world = c2w[..., :3, 3].view(*batch_dims, 1, 1, 3).expand(*batch_dims, H, W, 3)

        # Plücker coordinates: (m, d) with m = o × d
        m_world = torch.cross(o_world, d_world, dim=-1)
        plucker = torch.cat([m_world, d_world], dim=-1)  # shape: (*batch, H, W, 6)

        out['ray_origin'] = o_world
        out['ray_dir'] = d_world
        out['ray_plucker'] = plucker
    return out


def to_device(x, device='cuda'):
    """Recursively move tensors (inside dicts/lists/tuples) to ``device``.

    Scalars, strings, and None pass through unchanged; any other type raises.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    elif isinstance(x, dict):
        return {k: to_device(v, device) for k, v in x.items()}
    elif isinstance(x, list):
        return [to_device(v, device) for v in x]
    elif isinstance(x, tuple):
        return tuple(to_device(v, device) for v in x)
    elif isinstance(x, int) or isinstance(x, float) or isinstance(x, str) or x is None:
        return x
    else:
        raise ValueError(f'Unsupported type {type(x)}')