"""
model.py — Specialised Hα encoder (Stage 1 Hα) trained as a masked autoencoder (MAE).

Compact version of SpectralEncoder adapted to spectra cropped to 128 bins
around Hα (6512.8–6612.8 Å).

Architecture:
    Hα spectrum [128] → patches [31, 8] → projection [31, 128]
    → + wavelength PE → 60% masking → [CLS] + 13 visible patches [14, 128]
    → Transformer (4 layers) → CLS embedding z_halpha [128]
    → MAE decoder [reconstruction]

Key differences from the full encoder (stage1/model.py):
    - Input    : 128 bins (vs 4096)
    - Patches  : 8 px, overlap 4, step 4 → 31 patches (vs 511)
    - d_model  : 128 (vs 256)
    - n_layers : 4 (vs 6)
    - No GRL/discriminator (negative result confirmed)
    - ~300K parameters (vs ~5.3M)
"""
|
|
| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from config import ModelConfig |
|
|
|
|
| |
| |
| |
|
|
def contiguous_masking(n_patches: int, mask_ratio: float = 0.60,
                       n_blocks: int = 3) -> np.ndarray:
    """
    Generate a boolean MAE mask made of contiguous blocks of patches.

    Exactly ``min(int(n_patches * mask_ratio), n_patches)`` entries are True
    (True = masked); a non-positive count yields an all-False mask. A fixed
    count matters because the encoder reshapes the visible tokens of a batch
    to ``[B, n_visible, d]``, which needs the same number of visible patches
    in every row.

    Strategy:
      1. Place ``n_blocks`` blocks of ``n_masked // n_blocks`` patches at
         distinct random start positions (blocks may overlap).
      2. Top up with randomly chosen single patches until exactly
         ``n_masked`` patches are masked.

    With the defaults and 31 patches: int(31 * 0.6) = 18 masked patches,
    leaving 13 visible patches, so the encoder sees 13 + 1 [CLS] = 14 tokens.
    Only 3 blocks are used (vs 4 for the full encoder) because the sequence
    is short.

    Args:
        n_patches: Number of patches in the sequence.
        mask_ratio: Fraction of patches to mask.
        n_blocks: Number of contiguous blocks. It is reduced to the number of
            patches to mask when that is smaller, so blocks never overshoot.

    Returns:
        Boolean array of shape ``[n_patches]``, True = masked.
    """
    n_masked = min(int(n_patches * mask_ratio), n_patches)
    mask = np.zeros(n_patches, dtype=bool)
    if n_masked <= 0:
        return mask

    # Never use more blocks than patches to mask: size-1 blocks would
    # otherwise mask more than n_masked patches.
    n_blocks = max(1, min(n_blocks, n_masked))
    block_size = n_masked // n_blocks

    # A block of `block_size` fits at any start in [0, n_patches - block_size];
    # the "+ 1" lets a block end exactly on the last patch.
    possible_starts = np.arange(n_patches - block_size + 1)
    n_starts = min(n_blocks, len(possible_starts))
    starts = np.random.choice(possible_starts, size=n_starts, replace=False)
    for start in starts:
        mask[start:start + block_size] = True

    # Overlapping blocks leave a deficit: top it up with random single patches.
    deficit = n_masked - int(mask.sum())
    if deficit > 0:
        unmasked = np.flatnonzero(~mask)
        mask[np.random.choice(unmasked, size=deficit, replace=False)] = True

    return mask
|
|
|
|
| |
| |
| |
|
|
class WavelengthPE(nn.Module):
    """
    Sinusoidal positional encoding driven by wavelength (in Å).

    Same formula as the full encoder:
        PE(λ)[2i]   = sin(λ / wavelength_scale * div_term[i])
        PE(λ)[2i+1] = cos(λ / wavelength_scale * div_term[i])
    with ``div_term[i] = 10000 ** (-2i / d_model)`` and
    ``wavelength_scale = 10000`` by default.

    Position matters even inside the 100 Å Hα window: the blue
    (λ < 6562.8 Å) and red (λ > 6562.8 Å) wings of Hα carry different
    physical information (V/R ratio, asymmetries).

    NOTE(review): with the default ``wavelength_scale=10000``, a 100 Å window
    spans only ~0.01 rad even at the highest frequency, so the encoding is
    nearly constant across the Hα crop. A smaller ``wavelength_scale``
    increases positional resolution, but it changes the model's inputs, so
    only use it when training from scratch.
    """

    def __init__(self, d_model: int = 128, wavelength_scale: float = 10000.0):
        """
        Args:
            d_model: Size of the encoding (odd sizes are supported).
            wavelength_scale: Divisor applied to λ before the sinusoids.
                The default reproduces the historical behaviour.
        """
        super().__init__()
        self.d_model = d_model
        self.wavelength_scale = float(wavelength_scale)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # Registered as a buffer so it follows .to(device) and is saved
        # in the state_dict.
        self.register_buffer("div_term", div_term)

    def forward(self, lambda_means: torch.Tensor) -> torch.Tensor:
        """
        Args:
            lambda_means: [B, N_patches] mean λ of each patch (in Å).

        Returns:
            [B, N_patches, d_model] positional encoding.
        """
        angles = (lambda_means.unsqueeze(-1) / self.wavelength_scale) * self.div_term
        pe = torch.zeros(*lambda_means.shape, self.d_model,
                         device=lambda_means.device)
        pe[..., 0::2] = torch.sin(angles)
        # An odd d_model has one fewer cosine slot than frequencies.
        pe[..., 1::2] = torch.cos(angles[..., : self.d_model // 2])
        return pe
|
|
|
|
| |
| |
| |
|
|
class MAEDecoder(nn.Module):
    """
    Lightweight MAE decoder that reconstructs the patches of an Hα spectrum.

    Encoded visible tokens are projected to ``d_decoder`` and scattered back
    to their original positions. Masked positions are filled with a learned
    [MASK] token, positional encodings are added, and a small Transformer
    decodes the full sequence. A linear head then maps every token back to
    ``patch_size`` values.

    All sizes come from the constructor arguments. The intended setting is
    d_decoder = 64 (vs 128 for the full encoder's decoder), 2 layers and
    patch_size = 8 (vs 16); confirm against ModelConfig.
    """

    def __init__(self, d_encoder: int, d_decoder: int, n_layers: int,
                 n_heads: int, patch_size: int):
        super().__init__()
        self.d_decoder = d_decoder

        # Encoder width -> decoder width.
        self.proj = nn.Linear(d_encoder, d_decoder)

        # Learned placeholder inserted at every masked position.
        self.mask_token = nn.Parameter(torch.randn(1, 1, d_decoder) * 0.02)

        decoder_layer = nn.TransformerEncoderLayer(
            d_model=d_decoder, nhead=n_heads,
            dim_feedforward=d_decoder * 4,
            activation="gelu", batch_first=True, dropout=0.1,
        )
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=n_layers)

        # Each decoded token -> one patch of flux values.
        self.head = nn.Linear(d_decoder, patch_size)

    def forward(self, encoded_visible, visible_pe, full_pe, mask):
        """
        Reconstruct all patches (masked and visible).

        Args:
            encoded_visible: [B, N_vis, d_encoder] encoder outputs for the
                visible patches, in row-major order of ``~mask``.
            visible_pe: [B, N_vis, d_decoder]. Unused here; it is accepted
                only for interface compatibility.
            full_pe: [B, N_all, d_decoder] positional encoding of every patch.
            mask: [B, N_all], True = masked. Integer masks are converted to
                bool.

        Returns:
            [B, N_all, patch_size] reconstructed patches.
        """
        # `~` on an integer tensor is a bitwise NOT, so work on a bool mask.
        mask = mask.to(torch.bool)
        batch_size, n_all = mask.shape

        vis = self.proj(encoded_visible)

        # Index assignment needs identical dtypes; under autocast `vis` may be
        # fp16/bf16 while the parameter is fp32.
        full_seq = self.mask_token.to(vis.dtype).expand(batch_size, n_all, -1).clone()
        full_seq[~mask] = vis.reshape(-1, self.d_decoder)

        full_seq = full_seq + full_pe
        decoded = self.decoder(full_seq)
        return self.head(decoded)
|
|
|
|
| |
| |
| |
|
|
class SpectralEncoderHalpha(nn.Module):
    """
    MAE Transformer encoder for spectra cropped around Hα.

    Internal pipeline (figures for the default 128-bin input, 8-pixel patches
    with step 4, and 60% masking):
      1. Patchify    : split the spectrum into 31 overlapping 8-pixel patches
      2. Projection  : linear map of each patch [8] -> token [d_model]
      3. PE          : add a wavelength-based sinusoidal encoding
      4. Masking     : keep only the visible patches (13 of 31)
      5. [CLS]       : prepend a learned summary token
      6. Transformer : ``cfg.n_layers`` self-attention layers
      7. Output      : z_halpha = LayerNorm(CLS) in R^d_model

    Compared with the full 4096-bin encoder, this works on 31 patches instead
    of 511 and about 13 visible tokens instead of ~153 (the full encoder's
    figure).
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size
        self.patch_overlap = cfg.patch_overlap
        self.d_model = cfg.d_model
        self.step = cfg.patch_size - cfg.patch_overlap
        if self.step <= 0:
            raise ValueError(
                "patch_overlap must be smaller than patch_size "
                f"(got patch_size={cfg.patch_size}, "
                f"patch_overlap={cfg.patch_overlap})"
            )

        # Each raw patch -> d_model token.
        self.patch_proj = nn.Linear(cfg.patch_size, cfg.d_model)

        # Positional encoding computed from the mean wavelength of each patch.
        self.wave_pe = WavelengthPE(cfg.d_model)

        # Learned global summary token.
        self.cls_token = nn.Parameter(torch.randn(1, 1, cfg.d_model) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=cfg.d_model,
            nhead=cfg.n_heads,
            dim_feedforward=cfg.d_ff,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=cfg.n_layers
        )

        # Final normalisation of the CLS embedding.
        self.norm = nn.LayerNorm(cfg.d_model)

    def patchify(self, flux, wavelengths, validity):
        """
        Split spectra into overlapping patches along the last dimension.

        With the default 128 bins, patch_size=8 and step=4:
        N = (128 - 8) / 4 + 1 = 31 patches. Trailing bins that do not fill a
        whole patch are dropped by ``Tensor.unfold``.

        Returns:
            patches      : [B, N, patch_size] flux values
            lambda_means : [B, N] mean wavelength of each patch
            patch_valid  : [B, N] mean validity of each patch (the fraction
                           of valid pixels when validity is 0/1)
        """
        if not validity.is_floating_point():
            # bool/int validity flags: mean() requires a floating dtype.
            validity = validity.float()

        patches = flux.unfold(-1, self.patch_size, self.step)
        lambda_means = wavelengths.unfold(-1, self.patch_size, self.step).mean(-1)
        patch_valid = validity.unfold(-1, self.patch_size, self.step).mean(-1)
        return patches, lambda_means, patch_valid

    def forward(self, flux, wavelengths, validity, mask=None):
        """
        Encode a batch of spectra.

        Args:
            flux        : [B, L] flux values (L = 128 by default)
            wavelengths : [B, L] wavelength of each bin
            validity    : [B, L] per-bin validity (float, or bool/int flags)
            mask        : optional MAE mask, True = masked. Either [B, N] or
                          a single [N] mask shared by the whole batch. Any
                          tensor or array convertible to bool is accepted.
                          Every row must mask the same number of patches.

        Returns:
            z            : [B, d_model] CLS embedding (z_halpha)
            encoded      : [B, N_vis, d_model] encoded visible tokens
            patches      : [B, N, patch_size] all flux patches
            lambda_means : [B, N] mean λ per patch
            mask         : [B, N] bool mask actually used
            wpe          : [B, N, d_model] wavelength positional encoding

        Raises:
            ValueError: if ``mask`` has the wrong shape or if its rows do not
                all leave the same number of visible patches.
        """
        batch_size = flux.shape[0]

        patches, lambda_means, patch_valid = self.patchify(
            flux, wavelengths, validity
        )
        n_patches = patches.shape[1]

        wpe = self.wave_pe(lambda_means)
        tokens = self.patch_proj(patches) + wpe

        if mask is None:
            mask = torch.zeros(batch_size, n_patches, dtype=torch.bool,
                               device=flux.device)
        else:
            # Bool on the right device: `~` on an integer tensor would be a
            # bitwise NOT, not a logical one.
            mask = torch.as_tensor(mask, device=flux.device).bool()
            if mask.dim() == 1:
                mask = mask.unsqueeze(0).expand(batch_size, -1)
            if mask.shape != (batch_size, n_patches):
                raise ValueError(
                    f"mask must have shape ({batch_size}, {n_patches}), "
                    f"got {tuple(mask.shape)}"
                )

        visible_mask = ~mask
        visible_counts = visible_mask.sum(dim=1)
        n_visible = int(visible_counts[0].item())
        # The reshape below requires a rectangular [B, n_visible] selection.
        if bool((visible_counts != n_visible).any()):
            raise ValueError(
                "every row of mask must leave the same number of visible patches"
            )

        visible_tokens = tokens[visible_mask].view(batch_size, n_visible, self.d_model)
        visible_valid = patch_valid[visible_mask].view(batch_size, n_visible)

        cls = self.cls_token.expand(batch_size, -1, -1)
        input_tokens = torch.cat([cls, visible_tokens], dim=1)

        # Key padding mask (True = ignored). [CLS] is always attended; a
        # visible patch is ignored unless its mean validity exceeds 0.1.
        src_key_padding_mask = torch.cat([
            torch.zeros(batch_size, 1, dtype=torch.bool, device=flux.device),
            ~(visible_valid > 0.1),
        ], dim=1)

        encoded = self.encoder(
            input_tokens, src_key_padding_mask=src_key_padding_mask
        )

        z = self.norm(encoded[:, 0])
        return z, encoded[:, 1:], patches, lambda_means, mask, wpe
|
|
|
|
| |
| |
| |
|
|
class Stage1HalphaModel(nn.Module):
    """
    Stage 1 Hα model: a SpectralEncoderHalpha followed by an MAEDecoder.

    There is no instrument discriminator or gradient-reversal branch. The
    only training objective is the MSE reconstruction of the masked patches.
    """

    def __init__(self, model_cfg: ModelConfig):
        super().__init__()
        cfg = model_cfg

        # Submodules are registered in this order: encoder, decoder, PE
        # projection.
        self.encoder = SpectralEncoderHalpha(cfg)
        self.mae_decoder = MAEDecoder(
            d_encoder=cfg.d_model,
            d_decoder=cfg.d_decoder,
            n_layers=cfg.n_decoder_layers,
            n_heads=cfg.n_decoder_heads,
            patch_size=cfg.patch_size,
        )
        # Brings the encoder-width positional encoding down to decoder width.
        self.pe_proj = nn.Linear(cfg.d_model, cfg.d_decoder)

    def forward(self, flux, wavelengths, validity, mask):
        """
        Run the encoder, then the MAE reconstruction.

        Args:
            flux        : [B, 128]
            wavelengths : [B, 128]
            validity    : [B, 128]
            mask        : [B, 31] MAE mask (True = masked)

        Returns:
            dict with
              "z"             : [B, d_model] CLS embedding (z_halpha)
              "mae_loss"      : scalar. Mean squared error over the masked
                                patches, or over all patches when nothing is
                                masked.
              "reconstructed" : [B, N, patch_size] decoder output
              "patches"       : [B, N, patch_size] original patches (target)
              "mask"          : [B, N] mask used by the encoder
        """
        outputs = self.encoder(flux, wavelengths, validity, mask)
        z, encoded_vis, patches, _lambda_means, used_mask, wpe = outputs
        batch, _, _ = patches.shape

        pe_dec = self.pe_proj(wpe)
        keep = ~used_mask
        vis_count = keep[0].sum().item()
        pe_vis = pe_dec[keep].view(batch, vis_count, -1)

        recon = self.mae_decoder(encoded_vis, pe_vis, pe_dec, used_mask)

        sq_err = (recon - patches) ** 2
        weights = used_mask.unsqueeze(-1).expand_as(sq_err)
        denom = weights.sum()
        if denom > 0:
            loss = (sq_err * weights.float()).sum() / denom
        else:
            loss = sq_err.mean()

        return {
            "z": z,
            "mae_loss": loss,
            "reconstructed": recon,
            "patches": patches,
            "mask": used_mask,
        }

    def get_embeddings(self, flux, wavelengths, validity):
        """
        Return the unmasked CLS embedding z_halpha, [B, d_model].

        Gradient tracking and train/eval mode are left to the caller.
        """
        return self.encoder(flux, wavelengths, validity, mask=None)[0]
|
|