# 3AM / training_utils.py
# (HuggingFace page residue preserved as comments: nycu-cplab, commit 0bb5fcf "app overall")
from must3r.model import ActivationType, apply_activation
from dust3r.post_process import estimate_focal_knowing_depth
import torch
import random, math, roma
import torchvision.transforms.functional as TF
from tensordict import tensorclass
import torch.nn.functional as F
def save_checkpoint(model: torch.nn.Module, path: str, max_retries: int = None) -> None:
    """Save ``model.state_dict()`` to ``path``, retrying on transient I/O failures.

    Args:
        model: module whose parameters are saved.
        path: destination file path.
        max_retries: number of retries allowed after the first failure;
            ``None`` (default) retries forever, matching the original
            best-effort behavior on flaky filesystems.

    Raises:
        OSError, RuntimeError: re-raised once ``max_retries`` is exhausted.
    """
    attempts = 0
    while True:
        try:
            torch.save(model.state_dict(), path)
            return
        except (OSError, RuntimeError) as e:
            # Only retry plausibly-transient I/O/serialization failures.
            # The original caught *every* Exception and looped forever, which
            # hangs on non-transient errors (e.g. an unpicklable attribute).
            print(e)
            attempts += 1
            if max_retries is not None and attempts > max_retries:
                raise
def load_checkpoint(model: torch.nn.Module, ckpt_state_dict_raw: dict, strict = False) -> torch.nn.Module:
    """Best-effort load of a checkpoint state dict into ``model``.

    Args:
        model: module to load parameters into (modified in place).
        ckpt_state_dict_raw: checkpoint state dict.
        strict: when True, delegate to ``load_state_dict(strict=True)``;
            when False, load only the entries whose key exists in the model
            AND whose tensor shape matches, and report the rest.

    Returns:
        The (possibly partially updated) model.  Load failures are printed,
        never raised, so training can proceed from a fresh model.
    """
    try:
        if strict:
            model.load_state_dict(ckpt_state_dict_raw)
        else:
            model_dict = model.state_dict()
            # Keep only entries whose name AND shape match the model.
            compatible = {k: v for k, v in ckpt_state_dict_raw.items() if k in model_dict and v.shape == model_dict[k].shape}
            model_dict.update(compatible)
            model.load_state_dict(model_dict)
            skipped = set(ckpt_state_dict_raw.keys()) - set(compatible.keys())
            print(f'The following keys are in ckpt but not loaded: {skipped}')
    except Exception as e:
        # Best-effort: report and fall through to the return below.  The
        # original used `return` inside `finally`, which silently swallowed
        # even BaseException (KeyboardInterrupt, SystemExit).
        print(e)
    return model
def random_color_jitter(vid, brightness, contrast, saturation, hue = None):
    '''
    Apply brightness/saturation/contrast jitter (in random order) to a video.

    vid of shape [num_frames, num_channels, height, width]

    Each strength argument s > 0 samples a factor uniformly from [1, 1+s]
    (hue from [0, hue]); a strength of 0 (or None for hue) disables that
    transform.  Hue jitter is currently disabled (see comment below).
    '''
    assert vid.ndim == 4
    brightness_factor = random.uniform(1, 1 + brightness) if brightness > 0 else None
    contrast_factor = random.uniform(1, 1 + contrast) if contrast > 0 else None
    saturation_factor = random.uniform(1, 1 + saturation) if saturation > 0 else None
    # Guard for None first: the original compared `hue > 0` directly, which
    # raises TypeError with the default hue=None.
    hue_factor = random.uniform(0, hue) if hue is not None and hue > 0 else None
    vid_transforms = []
    # BUG FIX: the original gated on the raw strengths (`brightness is not
    # None`), so a strength of 0 still appended a transform with a None
    # factor and crashed inside torchvision.  Gate on the sampled factors.
    if brightness_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_brightness(img, brightness_factor))
    if saturation_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_saturation(img, saturation_factor))
    # Hue jitter intentionally disabled; kept for reference:
    # if hue_factor is not None:
    #     vid_transforms.append(lambda img: TF.adjust_hue(img, hue_factor))
    if contrast_factor is not None:
        vid_transforms.append(lambda img: TF.adjust_contrast(img, contrast_factor))
    random.shuffle(vid_transforms)
    for transform in vid_transforms:
        vid = transform(vid)
    return vid
@tensorclass
class BatchedVideoDatapoint:
    """
    A batch of videos with associated annotations and metadata.

    Attributes:
        img_batch: A [TxBxCxHxW] tensor containing the image data for each
            frame in the batch, where T is the number of frames per video and
            B is the number of videos in the batch.
        masks: A [TxOxHxW] tensor containing binary masks for each object in
            the batch, where O is the number of objects.
        flat_obj_to_img_idx: Per-object index into the image batch.
            # NOTE(review): the original docstring described an
            # "obj_to_frame_idx" field of shape [TxOx2]; confirm the exact
            # shape/semantics of this field against the collate code.
        features_3d: Optional per-frame 3D feature tensor (None when absent);
            flat_features_3d assumes it shares img_batch's leading [TxB] dims.
    """
    img_batch: torch.FloatTensor
    masks: torch.BoolTensor
    flat_obj_to_img_idx: torch.IntTensor
    features_3d: torch.FloatTensor = None
    def pin_memory(self, device=None):
        # Pin every member tensor so host->GPU copies can run asynchronously.
        return self.apply(torch.Tensor.pin_memory, device=device)
    @property
    def num_frames(self) -> int:
        """
        Returns the number of frames per video (size of dim 0).
        """
        return self.img_batch.shape[0]
    @property
    def num_videos(self) -> int:
        """
        Returns the number of videos in the batch (size of dim 1).
        """
        return self.img_batch.shape[1]
    @property
    def flat_img_batch(self) -> torch.FloatTensor:
        """
        Returns img_batch flattened to shape [(B*T)xCxHxW]
        (video-major: frames of one video are contiguous).
        """
        return self.img_batch.transpose(0, 1).flatten(0, 1)
    @property
    def flat_features_3d(self) -> torch.FloatTensor:
        """
        Returns features_3d flattened the same way as flat_img_batch,
        i.e. leading dims [TxB] -> [(B*T)].  Raises if features_3d is None.
        """
        return self.features_3d.transpose(0, 1).flatten(0, 1)
def sigmoid_focal_loss(
    inputs,
    targets,
    alpha: float = 0.5,
    gamma: float = 2,
):
    """
    Sigmoid focal loss from RetinaNet: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: Float tensor of arbitrary shape holding raw logits.
        targets: Float tensor with the same shape as ``inputs`` storing the
            binary label of each element (0 = negative, 1 = positive).
        alpha: Weighting factor balancing positive vs negative examples
            (default 0.5); any negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) balancing easy
            vs hard examples.

    Returns:
        Unreduced focal-loss tensor with the same shape as ``inputs``.
    """
    probs = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction = "none")
    # Probability the model assigns to the true class of each element.
    true_class_prob = probs * targets + (1 - probs) * (1 - targets)
    focal = bce * ((1 - true_class_prob) ** gamma)
    if alpha >= 0:
        # Class-balancing weight: alpha for positives, 1-alpha for negatives.
        balance = alpha * targets + (1 - alpha) * (1 - targets)
        focal = balance * focal
    return focal
def positional_encoding(positions, freqs, dim = 1):
    """
    Sinusoidal positional encoding along one dimension.

    The input is scaled by ``freqs`` power-of-two frequency bands; the sines
    and cosines of the scaled values are concatenated (together with the raw
    positions) along ``dim``, so that dimension grows from C to C*2*freqs + C.

    Args:
        positions (torch.Tensor): Positions to encode (e.g. (1, 3, 256, 256)).
        freqs (int): Number of frequency bands.
        dim (int): Dimension along which to encode. Default is 1.

    Returns:
        torch.Tensor: Encoded tensor, enlarged along ``dim``.
    """
    assert dim >= 0 and dim < positions.ndim, "Invalid dimension specified."
    # Frequency bands 2^0 .. 2^(freqs-1), on the input's dtype/device.
    bands = 2 ** torch.arange(freqs, dtype=positions.dtype, device=positions.device)
    # Insert a fresh axis right after `dim` and broadcast-multiply by the
    # bands; the view lines the band axis up with that new axis.
    trailing = positions.ndim - dim - 1
    scaled = positions.unsqueeze(dim + 1) * bands.view(-1, *([1] * trailing))
    # Fold the frequency axis back into `dim`.
    folded = scaled.reshape(*positions.shape[:dim], -1, *positions.shape[dim + 1:])
    # sin/cos encodings plus the raw positions, concatenated along `dim`.
    return torch.cat([torch.sin(folded), torch.cos(folded), positions], dim = dim)
@torch.autocast("cuda", dtype=torch.float32)
def postprocess_must3r_output(pointmaps, pointmaps_activation = ActivationType.NORM_EXP, compute_cam = True):
    """Convert raw MUSt3R head output into pointmaps, confidence and camera info.

    Args:
        pointmaps: channels-last tensor.  Channels [0:3] are the world-frame
            pointmap; [3:6] (when present) the camera-frame ("local")
            pointmap; the last channel (when channels == 4 or 7) a raw
            confidence logit.
        pointmaps_activation: activation applied to both pointmaps via
            ``apply_activation``.
        compute_cam: when True, also estimates focal, c2w pose and rays.

    Returns:
        dict with 'pts3d' and, depending on inputs: 'pts3d_local', 'conf',
        'focal', 'c2w', 'ray_origin', 'ray_dir', 'ray_plucker'.

    NOTE(review): the compute_cam branch reads out['conf'] and
    out['pts3d_local'], so it only works when channels == 7 — confirm
    callers never pass compute_cam=True with 3/4/6-channel input.
    """
    out = {}
    channels = pointmaps.shape[-1]
    out['pts3d'] = pointmaps[..., :3]
    out['pts3d'] = apply_activation(out['pts3d'], activation = pointmaps_activation)
    if channels >= 6:
        # Camera-frame pointmap shares the same activation.
        out['pts3d_local'] = pointmaps[..., 3:6]
        out['pts3d_local'] = apply_activation(out['pts3d_local'], activation = pointmaps_activation)
    if channels == 4 or channels == 7:
        # exp + 1 keeps confidence strictly > 1 (DUSt3R-style parametrization).
        out['conf'] = 1.0 + pointmaps[..., -1].exp()
    if compute_cam:
        # Leading dims before (H, W, 3) are treated as batch dims.
        batch_dims = out['pts3d'].shape[:-3]
        num_batch_dims = len(batch_dims)
        H, W = out['conf'].shape[-2:]
        # Principal point assumed at the image center — TODO confirm.
        pp = torch.tensor((W / 2, H / 2), device = out['pts3d'].device)
        focal = estimate_focal_knowing_depth(out['pts3d_local'].reshape(math.prod(batch_dims), H, W, 3), pp,
                                             focal_mode='weiszfeld')
        out['focal'] = focal.reshape(*batch_dims)
        # Rigid alignment local->world; conf - 1.0 (in (0, inf)) weights points.
        R, T = roma.rigid_points_registration(
            out['pts3d_local'].reshape(*batch_dims, -1, 3),
            out['pts3d'].reshape(*batch_dims, -1, 3),
            weights = out['conf'].reshape(*batch_dims, -1) - 1.0, compute_scaling = False)
        # Assemble 4x4 camera-to-world matrices, one per batch element.
        c2w = torch.eye(4, device=out['pts3d'].device)
        c2w = c2w.view(*([1] * num_batch_dims), 4, 4).repeat(*batch_dims, 1, 1)
        c2w[..., :3, :3] = R
        c2w[..., :3, 3] = T.view(*batch_dims, 3)
        out['c2w'] = c2w
        # pixel grid
        ys, xs = torch.meshgrid(
            torch.arange(H, device = out['pts3d'].device),
            torch.arange(W, device = out['pts3d'].device),
            indexing = 'ij'
        )
        # broadcast to batch
        f = out['focal'].reshape(*batch_dims, 1, 1)  # assume fx = fy = focal
        x = (xs - pp[0]) / f
        y = (ys - pp[1]) / f
        # directions in camera frame (unit-normalized)
        d_cam = torch.stack([x, y, torch.ones_like(x)], dim=-1)
        d_cam = F.normalize(d_cam, dim=-1)
        # rotate to world frame
        d_world = torch.einsum('...ij,...hwj->...hwi', R, d_cam)
        # camera center in world frame
        o_world = c2w[..., :3, 3].view(*batch_dims, 1, 1, 3).expand(*batch_dims, H, W, 3)
        # Plücker coordinates: (m, d) with m = o × d
        m_world = torch.cross(o_world, d_world, dim = -1)
        plucker = torch.cat([m_world, d_world], dim = -1)  # shape: (*batch, H, W, 6)
        out['ray_origin'] = o_world
        out['ray_dir'] = d_world
        out['ray_plucker'] = plucker
    return out
def to_device(x, device = 'cuda'):
    """Recursively move every tensor inside a nested container to ``device``.

    Dicts, lists and tuples are rebuilt with their elements recursed;
    ints, floats, strings and None pass through unchanged.

    Raises:
        ValueError: for any other type.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {key: to_device(val, device) for key, val in x.items()}
    if isinstance(x, list):
        return [to_device(item, device) for item in x]
    if isinstance(x, tuple):
        return tuple(to_device(item, device) for item in x)
    if isinstance(x, (int, float, str)) or x is None:
        return x
    raise ValueError(f'Unsupported type {type(x)}')