File size: 9,710 Bytes
0bb5fcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
from must3r.model import ActivationType, apply_activation
from dust3r.post_process import estimate_focal_knowing_depth
import torch
import random, math, roma
import torchvision.transforms.functional as TF 
from tensordict import tensorclass
import torch.nn.functional as F

def save_checkpoint(model: torch.nn.Module, path: str) -> None:
    """Save ``model``'s state dict to ``path``, retrying until the write succeeds.

    The retry-forever loop guards against transient filesystem errors
    (e.g. flaky network storage); each failure is printed and the save is
    immediately attempted again.

    Args:
        model: Module whose ``state_dict()`` is serialized.
        path: Destination file path passed to ``torch.save``.
    """
    while True:
        try:
            torch.save(model.state_dict(), path)
        except Exception as e:
            # Best-effort retry: report the error and try again.
            print(e)
        else:
            # Write succeeded; original used break-out-of-loop, return is equivalent.
            return

def load_checkpoint(model: torch.nn.Module, ckpt_state_dict_raw: dict, strict: bool = False) -> torch.nn.Module:
    """Load a checkpoint state dict into ``model`` and return the model.

    When ``strict`` is False, only checkpoint entries whose key exists in the
    model with an identical shape are loaded; the skipped keys are printed.
    Loading errors are printed rather than raised, so the (possibly partially
    updated) model is always returned.

    Args:
        model: Target module to load parameters into.
        ckpt_state_dict_raw: Raw checkpoint state dict.
        strict: If True, delegate to strict ``load_state_dict`` (all keys
            must match exactly).

    Returns:
        The same ``model`` instance, loaded on a best-effort basis.
    """
    try:
        if strict:
            model.load_state_dict(ckpt_state_dict_raw)
        else:
            model_dict = model.state_dict()
            # Keep only entries present in the model with a matching shape.
            ckpt_state_dict = {
                k: v for k, v in ckpt_state_dict_raw.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            model_dict.update(ckpt_state_dict)
            model.load_state_dict(model_dict)
            skipped = set(ckpt_state_dict_raw.keys()) - set(ckpt_state_dict.keys())
            print(f'The following keys are in ckpt but not loaded: {skipped}')
    except Exception as e:
        # Best-effort: report and fall through to returning the model.
        # The original used ``finally: return model``, which silently
        # swallowed *every* exception, including KeyboardInterrupt.
        print(e)
    return model
    

def random_color_jitter(vid, brightness, contrast, saturation, hue = None):
    '''
    Apply a randomly ordered brightness/saturation/contrast jitter to a video.

    Args:
        vid: Tensor of shape [num_frames, num_channels, height, width].
        brightness: Max brightness jitter; a factor is drawn uniformly from
            [1, 1 + brightness] when > 0, otherwise brightness is skipped.
        contrast: Max contrast jitter, same convention as brightness.
        saturation: Max saturation jitter, same convention as brightness.
        hue: Max hue jitter, factor drawn from [0, hue]; currently has no
            effect because the hue transform is commented out below. May be
            None.

    Returns:
        The jittered video tensor (unchanged when all strengths are 0).
    '''
    assert vid.ndim == 4

    brightness_factor = random.uniform(1, 1 + brightness) if brightness > 0 else None
    contrast_factor = random.uniform(1, 1 + contrast) if contrast > 0 else None
    saturation_factor = random.uniform(1, 1 + saturation) if saturation > 0 else None
    # Guard against the None default: the original ``hue > 0`` raised a
    # TypeError whenever hue was None.
    hue_factor = random.uniform(0, hue) if hue is not None and hue > 0 else None

    vid_transforms = []
    # Original bug: these checks tested the *strength* arguments
    # (``if brightness is not None``), which are always non-None floats, so a
    # strength of 0 still applied the transform with a None factor. Test the
    # sampled factors instead.
    if brightness_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_brightness(img, brightness_factor))
    if saturation_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_saturation(img, saturation_factor))
    # if hue_factor is not None:
    #     vid_transforms.append(lambda img: TF.adjust_hue(img, hue_factor))
    if contrast_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_contrast(img, contrast_factor))
    # Apply the selected transforms in random order.
    random.shuffle(vid_transforms)
    for transform in vid_transforms:
        vid = transform(vid)
    return vid


@tensorclass
class BatchedVideoDatapoint:
    """
    This class represents a batch of videos with associated annotations and metadata.

    Attributes:
        img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
        masks: A [TxOxHxW] tensor containing binary masks for each object in the batch, where O is the number of objects.
        flat_obj_to_img_idx: Per-object indices into the flattened image batch.
            (The original docstring described a [TxOx2] ``obj_to_frame_idx``
            field — presumably this is the same mapping renamed; confirm
            against the dataloader that builds it.)
        features_3d: Optional per-frame 3D feature tensor; defaults to None.
    """

    # Frame-major images: [T, B, C, H, W].
    img_batch: torch.FloatTensor
    # Per-object binary masks: [T, O, H, W].
    masks: torch.BoolTensor
    # Object -> flattened-image-batch index mapping (see class docstring).
    flat_obj_to_img_idx: torch.IntTensor
    # Optional 3D features aligned with img_batch; None when unavailable.
    features_3d: torch.FloatTensor = None
    
    def pin_memory(self, device=None):
        # tensorclass.apply maps the function over every tensor field,
        # returning a new instance with all tensors pinned.
        return self.apply(torch.Tensor.pin_memory, device=device)

    @property
    def num_frames(self) -> int:
        """
        Returns the number of frames per video (T, the leading dim of img_batch).
        """
        return self.img_batch.shape[0]

    @property
    def num_videos(self) -> int:
        """
        Returns the number of videos in the batch (B, the second dim of img_batch).
        """
        return self.img_batch.shape[1]

    @property
    def flat_img_batch(self) -> torch.FloatTensor:
        """
        Returns a flattened img_batch tensor of shape [(B*T)xCxHxW].
        """
        return self.img_batch.transpose(0, 1).flatten(0, 1)
    @property
    def flat_features_3d(self) -> torch.FloatTensor:
        """
        Returns a flattened features_3d tensor with the leading [TxB] dims
        merged to [(B*T)x...]. Raises AttributeError-like failure if
        features_3d is None.
        """
        return self.features_3d.transpose(0, 1).flatten(0, 1)

def sigmoid_focal_loss(
    inputs,
    targets,
    alpha: float = 0.5,
    gamma: float = 2,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = 0.5; pass a
                negative value to disable the alpha weighting entirely.
                (The original docstring incorrectly stated the default
                was -1.)
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.

    Returns:
        Unreduced focal loss tensor, same shape as ``inputs``.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction = "none")
    # p_t is the model's probability of the true class.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    # Down-weight easy examples by (1 - p_t)^gamma.
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss


def positional_encoding(positions, freqs, dim = 1):
    """
    Frequency-encode a tensor along one dimension.

    Every entry of ``positions`` is multiplied by ``freqs`` powers of two
    (2**0 .. 2**(freqs-1)); the sine and cosine of the scaled values are then
    concatenated with the raw positions along ``dim``, growing that dimension
    from C to C * (2 * freqs + 1).

    Args:
        positions (torch.Tensor): Input tensor (e.g. shape (1, 3, 256, 256)).
        freqs (int): Number of frequency bands.
        dim (int): Dimension to encode along. Default is 1.

    Returns:
        torch.Tensor: Tensor with the encoding applied along ``dim``.
    """
    assert 0 <= dim < positions.ndim, "Invalid dimension specified."
    trailing = positions.ndim - dim - 1
    # Powers of two, shaped (freqs, 1, ..., 1) so they broadcast over the
    # dimensions that follow ``dim``.
    bands = 2 ** torch.arange(freqs, dtype=positions.dtype, device=positions.device)
    bands = bands.view(-1, *([1] * trailing))
    # Insert a frequency axis right after ``dim`` and scale by each band.
    scaled = positions.unsqueeze(dim + 1) * bands
    # Fold the frequency axis back into ``dim``.
    lead, tail = positions.shape[:dim], positions.shape[dim + 1:]
    scaled = scaled.reshape(*lead, -1, *tail)
    # [sin(encoded), cos(encoded), raw positions] along ``dim``.
    return torch.cat((scaled.sin(), scaled.cos(), positions), dim=dim)


@torch.autocast("cuda", dtype=torch.float32)
def postprocess_must3r_output(pointmaps, pointmaps_activation = ActivationType.NORM_EXP, compute_cam = True):
    """Split a raw MUSt3R pointmap prediction into activated outputs and,
    optionally, per-view camera parameters and rays.

    Args:
        pointmaps: Tensor with packed channels in the last dim, shape
            [..., H, W, C]. The slicing below implies C in {3, 4, 6, 7}:
            3 global-point channels, optionally 3 local-point channels,
            optionally a confidence logit as the last channel — confirm the
            layout against the model head.
        pointmaps_activation: Activation applied to the point channels.
        compute_cam: If True, additionally estimate focal length, a
            camera-to-world pose, and per-pixel rays.

    Returns:
        dict with 'pts3d', optionally 'pts3d_local' and 'conf', and — when
        compute_cam is True — 'focal', 'c2w', 'ray_origin', 'ray_dir',
        'ray_plucker'.

    NOTE(review): the compute_cam branch reads out['conf'] and
    out['pts3d_local'], so it effectively requires channels == 7; with fewer
    channels it raises a KeyError. Confirm whether compute_cam=True is ever
    used with other channel counts.
    """
    out = {}
    channels = pointmaps.shape[-1]
    out['pts3d'] = pointmaps[..., :3]
    out['pts3d'] = apply_activation(out['pts3d'], activation = pointmaps_activation)
    if channels >= 6:
        # Local (per-camera-frame) points occupy channels 3:6 when present.
        out['pts3d_local'] = pointmaps[..., 3:6]
        out['pts3d_local'] = apply_activation(out['pts3d_local'], activation = pointmaps_activation)
    if channels == 4 or channels == 7:
        # Confidence logit -> value in (1, inf); the 1.0 offset is undone
        # below when conf is used as registration weights.
        out['conf'] = 1.0 + pointmaps[..., -1].exp()

    if compute_cam:
        # Everything before H, W is treated as batch dimensions.
        batch_dims = out['pts3d'].shape[:-3]
        num_batch_dims = len(batch_dims)
        H, W = out['conf'].shape[-2:]
        # Principal point assumed at the image center — TODO confirm.
        pp = torch.tensor((W / 2, H / 2), device = out['pts3d'].device)
        focal = estimate_focal_knowing_depth(out['pts3d_local'].reshape(math.prod(batch_dims), H, W, 3), pp,
                                             focal_mode='weiszfeld')
        out['focal'] = focal.reshape(*batch_dims)

        # Rigid alignment local -> global gives the camera-to-world pose;
        # weights strip the +1.0 confidence offset applied above.
        R, T = roma.rigid_points_registration(
            out['pts3d_local'].reshape(*batch_dims, -1, 3),
            out['pts3d'].reshape(*batch_dims, -1, 3),
            weights = out['conf'].reshape(*batch_dims, -1) - 1.0, compute_scaling = False)

        # Assemble 4x4 camera-to-world matrices per batch element.
        c2w = torch.eye(4, device=out['pts3d'].device)
        c2w = c2w.view(*([1] * num_batch_dims), 4, 4).repeat(*batch_dims, 1, 1)
        c2w[..., :3, :3] = R
        c2w[..., :3, 3] = T.view(*batch_dims, 3)
        out['c2w'] = c2w

        # pixel grid
        ys, xs = torch.meshgrid(
            torch.arange(H, device = out['pts3d'].device), 
            torch.arange(W, device = out['pts3d'].device), 
            indexing = 'ij'
        )
        # broadcast to batch
        f = out['focal'].reshape(*batch_dims, 1, 1)  # assume fx = fy = focal
        x = (xs - pp[0]) / f
        y = (ys - pp[1]) / f

        # directions in camera frame (unit length after normalization)
        d_cam = torch.stack([x, y, torch.ones_like(x)], dim=-1)
        d_cam = F.normalize(d_cam, dim=-1)

        # rotate to world frame
        d_world = torch.einsum('...ij,...hwj->...hwi', R, d_cam)

        # camera center in world frame, broadcast to every pixel
        o_world = c2w[..., :3, 3].view(*batch_dims, 1, 1, 3).expand(*batch_dims, H, W, 3)

        # Plücker coordinates: (m, d) with m = o × d
        m_world = torch.cross(o_world, d_world, dim = -1)
        plucker = torch.cat([m_world, d_world], dim = -1)  # shape: (*batch, H, W, 6)

        out['ray_origin'] = o_world
        out['ray_dir'] = d_world
        out['ray_plucker'] = plucker

    return out


def to_device(x, device = 'cuda'):
    """Recursively move every tensor inside a nested container to ``device``.

    Dicts, lists and tuples are rebuilt with the same structure; tensors are
    moved via ``.to(device)``; ints, floats, strings and None pass through
    unchanged. Any other type raises ValueError.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {key: to_device(val, device) for key, val in x.items()}
    if isinstance(x, (list, tuple)):
        moved = (to_device(item, device) for item in x)
        return list(moved) if isinstance(x, list) else tuple(moved)
    if x is None or isinstance(x, (int, float, str)):
        # Plain scalars and None are device-agnostic.
        return x
    raise ValueError(f'Unsupported type {type(x)}')