Upload 3 files
- mae.py +357 -0
- manas1.pt +3 -0
- modelclass.py +81 -0
mae.py
ADDED
@@ -0,0 +1,357 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    def __init__(self, fs: int = 200, patch_seconds: float = 1.0, overlap_seconds: float = 0.1, embed_dim: int = 512):
        super().__init__()

        self.patch_size = int(round(patch_seconds * fs))
        self.overlap_size = int(round(overlap_seconds * fs))

        self.step = self.patch_size - self.overlap_size

        self.linear = nn.Linear(self.patch_size, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Slide a patch_size window along the time axis with stride `step`
        patches = x.unfold(dimension=-1, size=self.patch_size, step=self.step)
        return self.linear(patches)

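# Shape check for PatchEmbed (illustrative, not in the original file): with fs=200,
# patch_seconds=1.0 and overlap_seconds=0.1, patch_size=200 and step=180, so a 10 s
# window (2000 samples) yields floor((2000 - 200) / 180) + 1 = 11 patches of 200 samples.
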
class PosEnc(nn.Module):
    def __init__(self, n_freqs: int = 4, embed_dim: int = 512):
        super().__init__()

        freqs = torch.linspace(1.0, 10.0, n_freqs)
        # (n_freqs^4, 4) grid of frequency combinations, stored transposed as (4, n_freqs^4)
        self.register_buffer("freq_matrix", torch.cartesian_prod(freqs, freqs, freqs, freqs).transpose(1, 0))

        fourier_features_dim = 2 * (n_freqs**4)

        self.fourier_linear = nn.Linear(fourier_features_dim, embed_dim, bias=False)
        self.learned_linear = nn.Sequential(nn.Linear(4, embed_dim, bias=False), nn.GELU(), nn.LayerNorm(embed_dim))

        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, coords: torch.Tensor):
        # coords: (B, N, 4) -> phases: (B, N, n_freqs^4)
        phases = torch.matmul(coords, self.freq_matrix)

        fourier_features = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)
        fourier_emb = self.fourier_linear(fourier_features)

        learned_emb = self.learned_linear(coords)

        return self.final_norm(fourier_emb + learned_emb)

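# Dimension check for PosEnc (illustrative): with n_freqs=4 the cartesian product has
# 4^4 = 256 frequency 4-vectors, so sin+cos give 2 * 256 = 512 Fourier features per token.
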
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim: int, heads: int, dropout: float = 0.0):
        super().__init__()

        assert embed_dim % heads == 0, "embed_dim must be divisible by heads"

        self.pre_attn_norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads, dropout=dropout, batch_first=True)

        self.pre_ffn_norm = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(), nn.Linear(4 * embed_dim, embed_dim))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Pre-LN self-attention with a residual connection
        attn_in = self.pre_attn_norm(x)
        attn_out, _ = self.attn(attn_in, attn_in, attn_in)
        x = x + attn_out

        # Pre-LN feed-forward with a residual connection
        ffn_in = self.pre_ffn_norm(x)
        ffn_out = self.ffn(ffn_in)
        x = x + ffn_out

        # ffn_out (pre-residual) is returned so callers can collect per-layer features
        return x, ffn_out

class TransformerEncoderDecoder(nn.Module):
    def __init__(self, embed_dim: int = 512, depth: int = 16, heads: int = 8):
        super().__init__()

        self.layers = nn.ModuleList([TransformerBlock(embed_dim, heads) for _ in range(depth)])
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        intermediate = []

        for layer in self.layers:
            x, ffn_out = layer(x)
            intermediate.append(ffn_out)

        return self.final_norm(x), intermediate

class MAEDecoder(nn.Module):
    def __init__(self, embed_dim: int = 512, decoder_depth: int = 4, decoder_heads: int = 8, patch_size: int = 200):
        super().__init__()

        # 1. The Mask Token (the "gray tile"):
        # a learnable vector that stands in for every missing patch
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.mask_token, std=0.02)

        # 2. The Decoder Transformer: reuses the encoder block,
        # but is lighter (fewer layers) than the main encoder
        self.decoder = TransformerEncoderDecoder(embed_dim=embed_dim, depth=decoder_depth, heads=decoder_heads)

        # 3. The Prediction Head:
        # projects an embedding (512) back to a raw signal patch (200)
        self.predict = nn.Linear(embed_dim, patch_size, bias=True)

    def forward(self, x_visible: torch.Tensor, pos_enc: nn.Module, coords: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        B, N_Total, D = coords.shape[0], coords.shape[1], x_visible.shape[-1]

        # --- Step A: Fill the canvas with mask tokens ---
        # Create a (Batch, Total, Dim) tensor filled with the mask token
        x_full = self.mask_token.expand(B, N_Total, D).clone()

        # --- Step B: Paste the visible tokens ---
        # Overwrite the mask tokens with the encoder output at the visible (True) slots
        for i in range(B):
            x_full[i, mask[i]] = x_visible[i]

        # --- Step C: Add positional encoding ---
        # The shared PosEnc maps coords (B, N_Total, 4) to (B, N_Total, Dim)
        pos_emb = pos_enc(coords)
        x_full = x_full + pos_emb

        # --- Step D: Decode ---
        # The intermediate outputs (second return value) are not needed here
        x_decoded, _ = self.decoder(x_full)

        # --- Step E: Predict ---
        # (Batch, N_Total, 512) -> (Batch, N_Total, 200)
        prediction = self.predict(x_decoded)

        return prediction

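# Note on Step B above (illustrative): since every sample has the same number of
# visible tokens, the per-sample loop is equivalent to one vectorized scatter:
#     x_full[mask] = x_visible.reshape(-1, x_visible.shape[-1])
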
def generate_mask(coords: torch.Tensor, mask_ratio: float = 0.55, spatial_radius: float = 3.0, temporal_radius: float = 3.0) -> torch.Tensor:
    B, N, _ = coords.shape
    device = coords.device

    # Exact number of tokens to hide per sample
    num_masked_target = int(mask_ratio * N)

    # Start with everything visible (True = visible, False = masked)
    mask = torch.ones(B, N, dtype=torch.bool, device=device)

    for b in range(B):
        spatial_coords = coords[b, :, :3]
        temporal_coords = coords[b, :, 3]

        # --- Phase 1: Block Masking Strategy ---
        # Keep masking blocks until we meet or exceed the target
        while (~mask[b]).sum() < num_masked_target:
            # Pick a random seed token
            seed_idx = torch.randint(0, N, (1,)).item()

            # Distances from the seed in space and in time
            seed_spatial = spatial_coords[seed_idx]
            dists_spatial = torch.norm(spatial_coords - seed_spatial, dim=1)

            seed_temporal = temporal_coords[seed_idx]
            dists_temporal = torch.abs(temporal_coords - seed_temporal)

            # Tokens within both radii form the block
            in_block = (dists_spatial <= spatial_radius) & (dists_temporal <= temporal_radius)

            # Mask this block (set to False)
            mask[b, in_block] = False

        # --- Phase 2: Exact Count Enforcement ---
        # The last block usually overshoots the target, so the excess must be unmasked.

        # Indices of all tokens that are currently masked
        masked_indices = torch.where(~mask[b])[0]
        num_current_masked = len(masked_indices)

        if num_current_masked > num_masked_target:
            # Randomly choose which ones stay masked: shuffle the masked indices,
            # the first num_masked_target stay masked, the rest become visible again
            shuffled_indices = masked_indices[torch.randperm(num_current_masked)]

            excess_indices = shuffled_indices[num_masked_target:]
            mask[b, excess_indices] = True

    return mask

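# Quick sanity check for generate_mask (illustrative):
#     coords = torch.randn(2, 100, 4)
#     m = generate_mask(coords, mask_ratio=0.55)
#     assert (~m).sum(dim=1).eq(int(0.55 * 100)).all()  # exactly 55 masked per row
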
class MAE(nn.Module):
    def __init__(
        self,
        # Data params
        fs: int = 200,
        patch_seconds: float = 1.0,
        overlap_seconds: float = 0.1,
        # Model params
        embed_dim: int = 512,
        encoder_depth: int = 12,
        encoder_heads: int = 8,
        decoder_depth: int = 4,
        decoder_heads: int = 8,
        # Training params
        mask_ratio: float = 0.55,
        aux_loss_weight: float = 0.1,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.mask_ratio = mask_ratio
        self.aux_loss_weight = aux_loss_weight

        # 1. Input processing
        self.patch_embed = PatchEmbed(fs, patch_seconds, overlap_seconds, embed_dim)

        # patch_size and step are derived from the component we just initialized
        self.patch_size = self.patch_embed.patch_size
        self.step = self.patch_embed.step

        # 2. Positional encoding (shared between encoder and decoder)
        self.pos_enc = PosEnc(n_freqs=4, embed_dim=embed_dim)

        # 3. Encoder
        self.encoder = TransformerEncoderDecoder(embed_dim=embed_dim, depth=encoder_depth, heads=encoder_heads)

        # 4. Decoder (main reconstruction path)
        self.decoder = MAEDecoder(embed_dim=embed_dim, decoder_depth=decoder_depth, decoder_heads=decoder_heads, patch_size=self.patch_size)

        # 5. Auxiliary head (global token)
        # Concatenates the outputs of ALL encoder layers
        self.aux_dim = encoder_depth * embed_dim

        # A learned query vector that attends over the encoder outputs
        self.aux_query = nn.Parameter(torch.randn(1, 1, self.aux_dim))
        nn.init.normal_(self.aux_query, std=0.02)

        # Projection: (depth * embed_dim) -> embed_dim
        self.aux_linear = nn.Linear(self.aux_dim, embed_dim, bias=False)

        # Reconstruction head for the auxiliary task
        self.aux_predict = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, self.patch_size))

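    # Sizing note (illustrative): with the defaults encoder_depth=12 and embed_dim=512,
    # aux_dim = 12 * 512 = 6144, so aux_query is a (1, 1, 6144) learned parameter.
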
    def prepare_coords(self, xyz: torch.Tensor, num_patches: int):
        B, C, _ = xyz.shape
        device = xyz.device

        # 1. Generate time indices (0, 1, 2, ..., P-1)
        time_idx = torch.arange(num_patches, device=device, dtype=torch.float32)

        # 2. Expand spatial coords
        # (B, C, 3) -> (B, C, 1, 3) -> (B, C, P, 3)
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)

        # 3. Expand time coords
        # (P,) -> (1, 1, P, 1) -> (B, C, P, 1)
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)

        # 4. Concatenate -> (B, C, P, 4)
        coords = torch.cat([spat, time], dim=-1)

        # 5. Flatten to (B, N_Total, 4)
        return coords.flatten(1, 2)

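    # Illustrative example: 20 channels x 11 time patches -> N_Total = 220 tokens,
    # each tagged with a 4D coordinate (x, y, z, t).
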
    def forward(self, x: torch.Tensor, xyz: torch.Tensor):
        B, _, _ = x.shape

        # --- 1. Patchify & Embed ---
        # patches: (B, C, P, PatchSize)
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]

        # tokens: (B, C, P, Dim)
        tokens = self.patch_embed.linear(patches)

        # Flatten to sequences: (B, N_Total, ...)
        tokens_flat = tokens.flatten(1, 2)
        patches_flat = patches.flatten(1, 2)  # reconstruction target for the loss

        # --- 2. Prepare 4D Coordinates ---
        coords = self.prepare_coords(xyz, num_patches)

        # --- 3. Generate Mask ---
        # generate_mask guarantees an identical masked count for every sample in the batch
        mask = generate_mask(coords, mask_ratio=self.mask_ratio)

        # --- 4. Prepare Encoder Input ---
        # Extract only the visible tokens and stack them. Because the visible count is
        # identical for every row b, boolean indexing followed by .view() is valid:
        # tokens_flat: (B, N_Total, D), mask: (B, N_Total) -> result: (B, N_Vis, D)
        n_vis = mask[0].sum().item()

        x_vis = tokens_flat[mask].view(B, n_vis, -1)
        coords_vis = coords[mask].view(B, n_vis, -1)

        # Add positional encoding
        pe_vis = self.pos_enc(coords_vis)
        x_vis = x_vis + pe_vis

        # --- 5. Encoder Forward ---
        x_encoded, intermediates = self.encoder(x_vis)

        # --- 6. Main Decoder Path ---
        predictions_main = self.decoder(x_visible=x_encoded, pos_enc=self.pos_enc, coords=coords, mask=mask)

        # --- 7. Auxiliary Path (Global Token) ---
        # Concatenate all intermediate layer outputs: (B, N_Vis, Depth*Dim)
        aux_input = torch.cat(intermediates, dim=-1)

        # Attention pooling: score = input @ query^T
        # (B, N_Vis, AuxDim) @ (1, AuxDim, 1) -> (B, N_Vis, 1)
        attn_scores = torch.matmul(aux_input, self.aux_query.transpose(1, 2))
        attn_weights = F.softmax(attn_scores, dim=1)

        # Pool: sum(weights * input) -> (B, 1, AuxDim)
        global_token = torch.sum(attn_weights * aux_input, dim=1, keepdim=True)

        # Project to embedding dim: (B, 1, Dim)
        global_emb = self.aux_linear(global_token)

        # Predict the masked patches from the global token:
        # 1. Get coords of the masked tokens (fixed count, so the reshape is clean)
        n_masked = (~mask[0]).sum().item()
        coords_masked = coords[~mask].view(B, n_masked, -1)

        pe_masked = self.pos_enc(coords_masked)

        # 2. Expand the global token to one copy per masked position
        global_expanded = global_emb.expand(-1, n_masked, -1)

        # 3. Combine & predict
        aux_pred_in = global_expanded + pe_masked
        predictions_aux = self.aux_predict(aux_pred_in)

        # --- 8. Loss Calculation ---
        # Target: only the masked patches
        target_masked = patches_flat[~mask].view(B, n_masked, -1)

        # Main loss (L1 on masked patches)
        pred_main_masked = predictions_main[~mask].view(B, n_masked, -1)
        loss_main = F.l1_loss(pred_main_masked, target_masked)

        # Auxiliary loss (L1 on masked patches)
        loss_aux = F.l1_loss(predictions_aux, target_masked)

        total_loss = loss_main + self.aux_loss_weight * loss_aux

        return total_loss, predictions_main, mask
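A minimal smoke test for the pretraining objective (illustrative sketch, not part of the upload; assumes 20 EEG channels and 10 s of signal at 200 Hz):

import torch
from mae import MAE

model = MAE()
x = torch.randn(2, 20, 2000)   # (B, C, T) raw signal
xyz = torch.randn(2, 20, 3)    # (B, C, 3) electrode coordinates
loss, pred, mask = model(x, xyz)
print(loss.item(), pred.shape, mask.shape)  # scalar, (2, 220, 200), (2, 220)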
manas1.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb221351143c45e71ed478a5622a1ccf8f140b983a613f6f5875c862ae48ba76
size 653413200
modelclass.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn

from mae import MAE

# Use cm positions from mne.get_montage(); these are included only for reference.
POSITIONS = {
    "Fp1": (-3.09, 11.46, 2.79),
    "Fp2": (2.84, 11.53, 2.77),
    "F3": (-5.18, 8.67, 7.87),
    "F4": (5.03, 8.74, 7.73),
    "F7": (-7.19, 7.31, 2.58),
    "F8": (7.14, 7.45, 2.51),
    "T3": (-8.60, 1.49, 3.12),
    "T4": (8.33, 1.53, 3.10),
    "C3": (-6.71, 2.34, 10.45),
    "C4": (6.53, 2.36, 10.37),
    "T5": (-8.77, 1.29, -0.77),
    "T6": (8.37, 1.17, -0.77),
    "P3": (-5.50, -4.42, 9.99),
    "P4": (5.36, -4.43, 10.05),
    "O1": (-3.16, -8.06, 5.48),
    "O2": (2.77, -8.05, 5.47),
    "Fz": (-0.12, 9.33, 10.26),
    "Cz": (-0.14, 2.76, 14.02),
    "Pz": (-0.17, -4.52, 12.67),
    "A2": (8.39, 0.20, -2.69),
}

class MANAS1(nn.Module):
    def __init__(self, checkpoint_path, num_classes=2, flat_dim=512):
        super().__init__()

        print(f"Loading checkpoint from {checkpoint_path}...")
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        # Rebuild the pretraining model and load the pretrained weights
        self.mae = MAE(fs=200, embed_dim=512, encoder_depth=12, encoder_heads=8, decoder_depth=4, decoder_heads=8, mask_ratio=0.55)
        self.mae.load_state_dict(ckpt["model_state_dict"])

        # Expose the encoder-side components for feature extraction
        self.patch_embed = self.mae.patch_embed
        self.pos_enc = self.mae.pos_enc
        self.encoder = self.mae.encoder
        self.patch_size = self.mae.patch_size
        self.step = self.mae.step

        self.flat_dim = flat_dim

        # # The classification head (currently disabled; see forward())
        # self.final_layer = nn.Sequential(
        #     nn.Flatten(),
        #     nn.RMSNorm(self.flat_dim),  # tutorial uses RMSNorm
        #     nn.Dropout(0.1),
        #     nn.Linear(self.flat_dim, num_classes),
        # )

    def prepare_coords(self, xyz, num_patches):
        # Same 4D (x, y, z, t) coordinate construction as MAE.prepare_coords
        B, C, _ = xyz.shape
        device = xyz.device
        time_idx = torch.arange(num_patches, device=device).float()
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)
        return torch.cat([spat, time], dim=-1).flatten(1, 2)

    def forward(self, x, pos):
        # Patchify and embed every channel; no masking at inference time
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]

        tokens = self.patch_embed.linear(patches).flatten(1, 2)

        coords = self.prepare_coords(pos, num_patches)
        pe = self.pos_enc(coords)

        x_enc = tokens + pe
        latents, _ = self.encoder(x_enc)

        # TODO: add a final layer for classification
        return latents
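A minimal usage sketch for the fine-tuning wrapper (illustrative; assumes the manas1.pt checkpoint from this upload and input channels ordered as in POSITIONS):

import torch
from modelclass import MANAS1, POSITIONS

model = MANAS1("manas1.pt").eval()
x = torch.randn(1, 20, 2000)  # 1 sample, 20 channels, 10 s at 200 Hz
pos = torch.tensor(list(POSITIONS.values())).unsqueeze(0)  # (1, 20, 3)
with torch.no_grad():
    latents = model(x, pos)   # (1, 220, 512) per-token features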