| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class PatchEmbed(nn.Module):
    """Split a raw signal into (optionally overlapping) patches and project each to an embedding."""

    def __init__(self, fs: int = 200, patch_seconds: float = 1.0, overlap_seconds: float = 0.1, embed_dim: int = 512):
        super().__init__()

        # Convert the second-based window spec into sample counts.
        self.patch_size = int(round(fs * patch_seconds))
        self.overlap_size = int(round(fs * overlap_seconds))

        # Hop between the starts of consecutive patches.
        self.step = self.patch_size - self.overlap_size

        # Bias-free projection of each raw patch into the model dimension.
        self.linear = nn.Linear(self.patch_size, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold the last (time) axis into patches, then embed every patch."""
        windows = x.unfold(-1, self.patch_size, self.step)
        return self.linear(windows)
| |
|
| |
|
class PosEnc(nn.Module):
    """Positional encoding for 4-D coordinates: fixed Fourier features plus a learned branch."""

    def __init__(self, n_freqs: int = 4, embed_dim: int = 512):
        super().__init__()

        # Every combination of per-axis frequencies, stored as a
        # (4, n_freqs**4) projection matrix applied to the raw coordinates.
        freqs = torch.linspace(1.0, 10.0, n_freqs)
        grid = torch.cartesian_prod(freqs, freqs, freqs, freqs)
        self.register_buffer("freq_matrix", grid.transpose(1, 0))

        # sin + cos per frequency combination.
        fourier_features_dim = 2 * (n_freqs**4)

        self.fourier_linear = nn.Linear(fourier_features_dim, embed_dim, bias=False)
        self.learned_linear = nn.Sequential(
            nn.Linear(4, embed_dim, bias=False),
            nn.GELU(),
            nn.LayerNorm(embed_dim),
        )

        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, coords: torch.Tensor):
        """Map (..., 4) coordinates to (..., embed_dim) positional embeddings."""
        phases = coords @ self.freq_matrix

        fourier_emb = self.fourier_linear(torch.cat((phases.sin(), phases.cos()), -1))
        learned_emb = self.learned_linear(coords)

        # Sum both branches and normalize jointly.
        return self.final_norm(fourier_emb + learned_emb)
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block that also exposes the raw FFN branch output."""

    def __init__(self, embed_dim: int, heads: int, dropout: float = 0.0):
        super().__init__()

        assert embed_dim % heads == 0, "dim must be divisible by heads"

        self.pre_attn_norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads, dropout=dropout, batch_first=True)

        self.pre_ffn_norm = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Self-attention then FFN, each with a residual; returns (output, ffn_branch)."""
        normed = self.pre_attn_norm(x)
        attended, _ = self.attn(normed, normed, normed)
        h = x + attended

        branch = self.ffn(self.pre_ffn_norm(h))
        return h + branch, branch
| |
|
| |
|
class TransformerEncoderDecoder(nn.Module):
    """Stack of TransformerBlocks with a final LayerNorm; collects each block's FFN output."""

    def __init__(self, embed_dim: int = 512, depth: int = 16, heads: int = 8):
        super().__init__()

        self.layers = nn.ModuleList(TransformerBlock(embed_dim, heads) for _ in range(depth))
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Return (normalized final output, list of per-layer FFN branch outputs)."""
        ffn_outputs: list[torch.Tensor] = []

        for block in self.layers:
            x, ffn = block(x)
            ffn_outputs.append(ffn)

        return self.final_norm(x), ffn_outputs
| |
|
| |
|
class MAEDecoder(nn.Module):
    """MAE decoder: fill masked slots with a learned token, add positional
    encodings, run a small transformer, and regress the raw patch values.

    Mask convention (same as the rest of this file): True = visible token,
    False = masked token to be reconstructed.
    """

    def __init__(self, embed_dim: int = 512, decoder_depth: int = 4, decoder_heads: int = 8, patch_size: int = 200):
        super().__init__()

        # Learned placeholder embedding shared by every masked position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.mask_token, std=0.02)

        # Shallow transformer over the full (visible + masked) token sequence.
        self.decoder = TransformerEncoderDecoder(embed_dim=embed_dim, depth=decoder_depth, heads=decoder_heads)

        # Project each decoded token back to a raw signal patch.
        self.predict = nn.Linear(embed_dim, patch_size, bias=True)

    def forward(self, x_visible: torch.Tensor, pos_enc: nn.Module, coords: torch.Tensor, mask: torch.Tensor):
        """Reconstruct raw patches for ALL token positions.

        Args:
            x_visible: (B, n_visible, D) encoder outputs for the visible tokens;
                every sample must carry the same number of visible tokens (the
                mask generator in this file guarantees this).
            pos_enc: module mapping (B, N, 4) coordinates to (B, N, D) embeddings.
            coords: (B, N_total, 4) coordinates for every token slot.
            mask: (B, N_total) bool, True where the token was visible.

        Returns:
            (B, N_total, patch_size) predicted raw patches.
        """
        B, N_total = coords.shape[0], coords.shape[1]
        D = x_visible.shape[-1]

        # Start from mask tokens everywhere, then scatter the visible embeddings
        # back into their original slots. The vectorized boolean-index assignment
        # replaces a per-sample Python loop; row order of x_full[mask] matches
        # x_visible.reshape(-1, D) because every sample has the same visible count.
        x_full = self.mask_token.expand(B, N_total, D).clone()
        x_full[mask] = x_visible.reshape(-1, D)

        # Positional information for every slot, visible and masked alike.
        x_full = x_full + pos_enc(coords)

        # Decode the full sequence (intermediate FFN outputs are unused here).
        x_decoded, _ = self.decoder(x_full)

        return self.predict(x_decoded)
| |
|
| |
|
def generate_mask(coords: torch.Tensor, mask_ratio: float = 0.55, spatial_radius: float = 3.0, temporal_radius: float = 3.0):
    """Build a per-sample visibility mask by masking spatio-temporal blocks.

    Convention: True = visible token, False = masked token. Blocks are grown
    around random seed tokens — every token within ``spatial_radius`` of the
    seed's xyz AND within ``temporal_radius`` of its time index is masked —
    until at least ``mask_ratio * N`` tokens are masked. Any overshoot is then
    randomly flipped back to visible, so each sample ends with EXACTLY
    ``int(mask_ratio * N)`` masked tokens (callers rely on this uniform count).

    Args:
        coords: (B, N, 4) token coordinates; columns 0..2 are spatial xyz,
            column 3 is the temporal index.
        mask_ratio: target fraction of tokens to mask per sample.
        spatial_radius: Euclidean radius of a masked block in xyz units.
        temporal_radius: half-width of a masked block along the time axis.

    Returns:
        (B, N) bool tensor on ``coords``'s device, True where visible.
    """
    B, N, _ = coords.shape
    device = coords.device

    num_masked_target = int(mask_ratio * N)

    # Start fully visible; carve out masked blocks per sample.
    mask = torch.ones(B, N, dtype=torch.bool, device=device)

    for b in range(B):
        spatial_coords = coords[b, :, :3]
        temporal_coords = coords[b, :, 3]

        # Keep masking blocks around random seeds until the quota is reached.
        # The seed itself always falls inside its own block, so each iteration
        # masks at least one token and the loop terminates.
        while (~mask[b]).sum() < num_masked_target:
            seed_idx = torch.randint(0, N, (1,)).item()

            dists_spatial = torch.norm(spatial_coords - spatial_coords[seed_idx], dim=1)
            dists_temporal = torch.abs(temporal_coords - temporal_coords[seed_idx])

            in_block = (dists_spatial <= spatial_radius) & (dists_temporal <= temporal_radius)
            mask[b, in_block] = False

        # Trim the overshoot: randomly un-mask the excess so exactly
        # num_masked_target tokens remain masked.
        masked_indices = torch.where(~mask[b])[0]
        num_current_masked = masked_indices.numel()

        if num_current_masked > num_masked_target:
            shuffled_indices = masked_indices[torch.randperm(num_current_masked)]
            mask[b, shuffled_indices[num_masked_target:]] = True

    return mask
| |
|
| |
|
class MAE(nn.Module):
    """Masked autoencoder over patched multi-channel signals.

    Pipeline (see ``forward``): patchify along time -> embed -> block-mask ->
    encode visible tokens -> (a) a transformer decoder reconstructs raw patches
    at every position, and (b) an auxiliary head reconstructs the masked
    patches from a single attention-pooled global token. The training loss is
    L1 on the masked patches from both heads, mixed by ``aux_loss_weight``.
    """

    def __init__(
        self,
        # Patching: sampling rate plus window / overlap lengths in seconds.
        fs: int = 200,
        patch_seconds: float = 1.0,
        overlap_seconds: float = 0.1,
        # Transformer sizes.
        embed_dim: int = 512,
        encoder_depth: int = 12,
        encoder_heads: int = 8,
        decoder_depth: int = 4,
        decoder_heads: int = 8,
        # Masking fraction and loss mixing weight.
        mask_ratio: float = 0.55,
        aux_loss_weight: float = 0.1,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.mask_ratio = mask_ratio
        self.aux_loss_weight = aux_loss_weight

        # Raw-signal -> token projection; also defines the patch geometry.
        self.patch_embed = PatchEmbed(fs, patch_seconds, overlap_seconds, embed_dim)

        # Mirror the patch geometry locally; forward() re-does the unfold itself.
        self.patch_size = self.patch_embed.patch_size
        self.step = self.patch_embed.step

        # Positional encoding over (x, y, z, t) token coordinates.
        self.pos_enc = PosEnc(n_freqs=4, embed_dim=embed_dim)

        # Encoder runs on visible tokens only.
        self.encoder = TransformerEncoderDecoder(embed_dim=embed_dim, depth=encoder_depth, heads=encoder_heads)

        # Decoder reconstructs raw patches for every token position.
        self.decoder = MAEDecoder(embed_dim=embed_dim, decoder_depth=decoder_depth, decoder_heads=decoder_heads, patch_size=self.patch_size)

        # Auxiliary head consumes the concatenation of every encoder layer's
        # FFN output, hence depth * embed_dim features per token.
        self.aux_dim = encoder_depth * embed_dim

        # Learned attention-pooling query over the concatenated features.
        self.aux_query = nn.Parameter(torch.randn(1, 1, self.aux_dim))
        nn.init.normal_(self.aux_query, std=0.02)  # overwrites the randn values above

        # Project the pooled aux features back to the model dimension.
        self.aux_linear = nn.Linear(self.aux_dim, embed_dim, bias=False)

        # Predict a raw patch from (global token + positional encoding).
        self.aux_predict = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, self.patch_size))

    def prepare_coords(self, xyz: torch.Tensor, num_patches: int) -> torch.Tensor:
        """Build a 4-D coordinate (x, y, z, patch index) for every token.

        Args:
            xyz: (B, C, 3) spatial coordinates, one xyz triple per channel.
            num_patches: number of time patches per channel.

        Returns:
            (B, C * num_patches, 4) coordinates, flattened in the same
            (channel, patch) order that ``forward`` uses for its tokens.
        """
        B, C, _ = xyz.shape
        device = xyz.device

        # Patch index doubles as the temporal coordinate.
        time_idx = torch.arange(num_patches, device=device, dtype=torch.float32)

        # Repeat each channel's xyz across its patches: (B, C, num_patches, 3).
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)

        # Broadcast the patch index across batch and channels: (B, C, num_patches, 1).
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)

        # (B, C, num_patches, 4) = spatial xyz + temporal index.
        coords = torch.cat([spat, time], dim=-1)

        # Merge (channel, patch) into a single token axis.
        return coords.flatten(1, 2)

    def forward(self, x: torch.Tensor, xyz: torch.Tensor):
        """Run one MAE training pass.

        Args:
            x: (B, C, T) raw signals; time (last axis) is unfolded into patches.
            xyz: (B, C, 3) per-channel spatial coordinates.

        Returns:
            (total_loss, predictions_main, mask) where predictions_main is
            (B, C * num_patches, patch_size) and mask is (B, C * num_patches)
            bool with True = visible token.
        """
        B, _, _ = x.shape

        # Patchify: (B, C, num_patches, patch_size). Same unfold parameters as
        # PatchEmbed, done here so the raw patches are available as targets.
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]

        # Embed raw patches by reusing the PatchEmbed projection directly.
        tokens = self.patch_embed.linear(patches)

        # Flatten (channel, patch) into one token axis: (B, C*num_patches, ...).
        tokens_flat = tokens.flatten(1, 2)
        patches_flat = patches.flatten(1, 2)

        # (B, C*num_patches, 4) coordinates, same token ordering as above.
        coords = self.prepare_coords(xyz, num_patches)

        # True = visible, False = masked; generate_mask leaves exactly
        # int(mask_ratio * N) masked tokens in every sample.
        mask = generate_mask(coords, mask_ratio=self.mask_ratio)

        # Because the masked count is identical across the batch, the visible
        # count from sample 0 is valid batch-wide and the boolean-indexed
        # gather below reshapes cleanly to (B, n_vis, ...).
        n_vis = mask[0].sum().item()

        x_vis = tokens_flat[mask].view(B, n_vis, -1)
        coords_vis = coords[mask].view(B, n_vis, -1)

        # Add positional information before encoding.
        pe_vis = self.pos_enc(coords_vis)
        x_vis = x_vis + pe_vis

        # Encode visible tokens; keep per-layer FFN outputs for the aux head.
        x_encoded, intermediates = self.encoder(x_vis)

        # Main head: reconstruct raw patches at ALL token positions.
        predictions_main = self.decoder(x_visible=x_encoded, pos_enc=self.pos_enc, coords=coords, mask=mask)

        # ---- Auxiliary branch ----
        # Concatenate every encoder layer's FFN output: (B, n_vis, aux_dim).
        aux_input = torch.cat(intermediates, dim=-1)

        # Attention-pool the visible tokens into one global token:
        # scores (B, n_vis, 1) -> softmax over the token axis.
        attn_scores = torch.matmul(aux_input, self.aux_query.transpose(1, 2))
        attn_weights = F.softmax(attn_scores, dim=1)

        # Weighted sum over tokens: (B, 1, aux_dim).
        global_token = torch.sum(attn_weights * aux_input, dim=1, keepdim=True)

        # Back to the model dimension: (B, 1, embed_dim).
        global_emb = self.aux_linear(global_token)

        # Predict each masked patch from the shared global token plus that
        # patch's positional encoding. n_masked is batch-uniform (see above).
        n_masked = (~mask[0]).sum().item()
        coords_masked = coords[~mask].view(B, n_masked, -1)

        pe_masked = self.pos_enc(coords_masked)

        # Broadcast the single global token across all masked positions.
        global_expanded = global_emb.expand(-1, n_masked, -1)

        aux_pred_in = global_expanded + pe_masked
        predictions_aux = self.aux_predict(aux_pred_in)

        # Targets are the raw (un-embedded) patches at masked positions.
        target_masked = patches_flat[~mask].view(B, n_masked, -1)

        # L1 reconstruction loss on masked positions only (standard MAE).
        pred_main_masked = predictions_main[~mask].view(B, n_masked, -1)
        loss_main = F.l1_loss(pred_main_masked, target_masked)

        loss_aux = F.l1_loss(predictions_aux, target_masked)

        total_loss = loss_main + self.aux_loss_weight * loss_aux

        return total_loss, predictions_main, mask
| |
|