| import os |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchinterp1d import interp1d |
| import torch.optim as optim |
| import matplotlib.pyplot as plt |
| import os |
| from tqdm import tqdm |
| from dataset import desi |
|
|
| from scipy.ndimage import gaussian_filter1d |
| from astropy.convolution import convolve, Gaussian1DKernel |
| import numpy as np |
| import logging |
|
|
class PatchEmbed1D(nn.Module):
    """Tokenize a 1-D spectrum: zero-pad on the right to a whole number of
    patches, then linearly project each patch to a d_model-dim token."""

    def __init__(self, patch_size=16, d_model=512):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(patch_size, d_model)

    def forward(self, x):
        """Map (B, L) flux arrays to ((B, ceil(L/patch_size), d_model) tokens,
        padded length)."""
        batch, length = x.shape
        remainder = length % self.patch_size
        if remainder:
            # Pad only the trailing edge so patch boundaries stay aligned.
            x = F.pad(x, (0, self.patch_size - remainder))
        padded_len = x.shape[1]
        patches = x.view(batch, -1, self.patch_size)
        return self.proj(patches), padded_len
|
|
|
|
class ReconTransformer(nn.Module):
    """Transformer that maps an observed-frame spectrum of length
    ``orig_length`` to a rest-frame reconstruction of length ``target_length``.

    Pipeline: patch-embed the input, add learned positional embeddings,
    optionally add a redshift conditioning vector to every token, run a
    Transformer encoder, resample the token sequence to the output patch
    count, and project each token back to flux values.
    """

    def __init__(self,
                 orig_length=7781,
                 target_length=9780,
                 patch_size=16,
                 d_model=512,
                 nhead=8,
                 num_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 use_z_cond=True):
        """
        Args:
            orig_length: expected length of the observed input spectrum.
            target_length: length of the reconstructed output spectrum.
            patch_size: flux bins per token.
            d_model: token embedding width.
            nhead: attention heads per encoder layer.
            num_layers: number of encoder layers.
            dim_feedforward: hidden width of the encoder feed-forward MLPs.
            dropout: dropout probability inside the encoder.
            use_z_cond: if True, condition every token on redshift ``z``.
        """
        super().__init__()

        self.orig_length = orig_length
        self.target_length = target_length
        self.patch_size = patch_size
        self.d_model = d_model
        self.use_z_cond = use_z_cond

        self.patch_embed = PatchEmbed1D(patch_size=patch_size, d_model=d_model)

        # Positional table sized for orig_length; inputs producing more
        # patches than this would overrun the slice in _encode.
        max_patches = math.ceil(orig_length / patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, max_patches, d_model) * 0.02)

        if use_z_cond:
            # Small MLP lifting the scalar redshift to a d_model vector.
            self.z_proj = nn.Sequential(
                nn.Linear(1, d_model),
                nn.GELU(),
                nn.Linear(d_model, d_model)
            )
        else:
            self.z_proj = None

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.recon_head = nn.Linear(d_model, patch_size)
        self.norm = nn.LayerNorm(d_model)
        self._init_weights()
        # Number of output tokens needed to cover target_length.
        self.out_patches = math.ceil(self.target_length / patch_size)

    def _init_weights(self):
        # Xavier init for the reconstruction head only; the encoder keeps
        # PyTorch's default initialization.
        nn.init.xavier_uniform_(self.recon_head.weight)
        if self.recon_head.bias is not None:
            nn.init.zeros_(self.recon_head.bias)

    def _encode(self, spec, z):
        """Shared patch-embed -> pos-embed -> z-condition -> encoder path.

        Args:
            spec: (B, L) batch of spectra.
            z: optional (B,) or (B, 1) redshifts; ignored unless use_z_cond.

        Returns:
            Encoder output of shape (B, N_in, d_model).
        """
        tokens, _ = self.patch_embed(spec)
        n_in = tokens.shape[1]

        tokens = tokens + self.pos_embed[:, :n_in, :]

        if self.use_z_cond and z is not None:
            if z.dim() == 1:
                z = z.unsqueeze(-1)
            # One conditioning vector per spectrum, broadcast over tokens.
            tokens = tokens + self.z_proj(z).unsqueeze(1)

        tokens = self.norm(tokens)
        return self.encoder(tokens)

    def forward(self, spec, z=None):
        """Reconstruct a target_length spectrum from `spec`.

        Args:
            spec: (B, L) observed spectra.
            z: optional (B,) or (B, 1) redshifts for conditioning.

        Returns:
            (B, target_length) reconstructed flux.
        """
        enc_out = self._encode(spec, z)

        # Resample the token sequence so the head emits exactly out_patches.
        if self.out_patches != enc_out.shape[1]:
            enc_out = F.interpolate(
                enc_out.transpose(1, 2),
                size=self.out_patches,
                mode="linear",
                align_corners=False
            ).transpose(1, 2)

        patches = self.recon_head(enc_out)
        recon = patches.reshape(spec.shape[0], self.out_patches * self.patch_size)

        # Trim the tail introduced by rounding up to whole patches.
        return recon[:, :self.target_length]

    def extract_features(self, spec, z=None, pool="mean"):
        """Return a pooled encoder representation of `spec`.

        Args:
            spec: (B, L) spectra.
            z: optional redshifts, as in forward().
            pool: "mean" (token average), "cls" (first token), or
                "flatten" (all tokens concatenated per sample).

        Raises:
            ValueError: if `pool` is not a supported mode.
        """
        enc_out = self._encode(spec, z)

        if pool == "mean":
            return enc_out.mean(dim=1)
        if pool == "cls":
            return enc_out[:, 0, :]
        if pool == "flatten":
            # Bugfix: the original raised ValueError unconditionally right
            # after computing the flattened features, and unknown modes fell
            # through to a NameError; now only unknown modes raise.
            return enc_out.reshape(enc_out.shape[0], -1)
        raise ValueError("Unknown pooling method")
|
|
|
|
def to_rest_frame(spec_obs, z, wave_obs_min=3600.0, wave_obs_max=9824.0,
                  n_obs=7781, n_rest=9780, device="cuda", z_max=0.8):
    """Resample observed-frame spectra onto a shared rest-frame grid.

    Args:
        spec_obs: (B, n_obs) fluxes sampled on a linear observed-frame grid
            spanning [wave_obs_min, wave_obs_max].
        z: (B,) redshifts.
        wave_obs_min / wave_obs_max: observed wavelength range.
        n_obs: number of observed wavelength bins.
        n_rest: number of rest-frame bins.
        device: device on which the wavelength grids are built.
        z_max: largest redshift the rest grid must cover; sets the blue end
            of the rest-frame grid (previously hard-coded to 0.8).

    Returns:
        (spec_rest, wave_rest): (B, n_rest) resampled fluxes with bins that
        fall outside the observed coverage zeroed, and the (n_rest,) grid.
    """
    B = spec_obs.size(0)

    wave_obs = torch.linspace(wave_obs_min, wave_obs_max, n_obs,
                              dtype=torch.float32, device=device)
    wave_obs_batch = wave_obs.unsqueeze(0).expand(B, -1)

    # Rest-frame span: bluest wavelength reachable at z_max, reddest at z=0.
    lmbda_min = wave_obs_min / (1.0 + z_max)
    lmbda_max = wave_obs_max / 1.0

    wave_rest = torch.linspace(lmbda_min, lmbda_max, n_rest,
                               dtype=torch.float32, device=device)

    wave_rest_batch = wave_rest.unsqueeze(0).expand(B, -1)
    # Observed-frame position of each rest-frame bin, per object.
    wave_redshifted = wave_rest_batch * (1 + z.unsqueeze(-1))

    spec_rest = interp1d(wave_obs_batch, spec_obs, wave_redshifted)
    # Rest bins mapping outside the observed range have no data; zero them
    # rather than trusting interp1d's behaviour beyond the grid edges.
    valid = (wave_redshifted >= wave_obs_min) & (wave_redshifted <= wave_obs_max)
    spec_rest[~valid] = 0.0

    return spec_rest, wave_rest
|
|
def weighted_mse_loss(recon, target, weight):
    """Weight-normalized squared error between `recon` and `target`.

    Falls back to a plain (unweighted) MSE when the weights sum to zero,
    avoiding a 0/0."""
    sq_err = (recon - target).pow(2)
    total_weight = weight.sum()
    if total_weight == 0:
        return sq_err.mean()
    return (sq_err * weight).sum() / total_weight
|
|
def consistency_loss(sf, sf_aug, individual=False):
    """Similarity penalty between features of a spectrum and its augmentation.

    The per-sample squared distance is scaled by sigma=0.5 and the feature
    width, then squashed with a sigmoid shifted so zero distance gives zero
    loss. Returns the batch sum, or (distance, per-sample loss) when
    `individual` is True."""
    _, feat_dim = sf.shape
    sigma = 0.5
    dist = ((sf_aug - sf) ** 2 / sigma ** 2).sum(dim=1) / feat_dim
    per_sample = torch.sigmoid(dist) - 0.5
    if individual:
        return dist, per_sample
    return per_sample.sum()
|
|
|
|
def train_model(desi, model, trainloader, validloader, device, epochs=20, lr=1e-4, weight_decay=1e-5,
                save_path="best_recon_transformer.pth", plot_interval=1, fig_dir="figures",
                log_interval=5000, log_file="training.log"):
    """Train `model` to reconstruct rest-frame spectra from observed ones.

    Each training batch: the observed flux is smoothed on CPU (scipy
    gaussian_filter1d followed by an astropy Gaussian kernel convolution),
    moved to `device`, mapped to the rest frame, and the model is optimized
    on a weighted MSE plus a feature-consistency term computed against an
    augmented copy of the batch. Validation uses the weighted MSE only.

    NOTE(review): validation does NOT apply the smoothing that training
    applies — confirm this train/valid preprocessing asymmetry is intended.

    Args:
        desi: dataset helper providing `augment_spectra(spec, w, z)`.
        model: module with `forward(spec, z)` and `extract_features(spec, z)`.
        trainloader, validloader: yield 8-tuples; only the first three
            entries (spec, w, z) are used — presumably flux, weights,
            redshift (TODO confirm against the dataset module).
        device: torch device used for training.
        epochs, lr, weight_decay: AdamW optimizer settings.
        save_path: file the best (lowest val-loss) state_dict is written to.
        plot_interval: epochs between validation reconstruction figures.
        fig_dir: output directory for diagnostic figures.
        log_interval: training iterations between intermediate logging/plots.
        log_file: training log path (truncated at the start of each run).
    """

    # Ensure output directories exist before any figure/checkpoint write.
    os.makedirs(fig_dir, exist_ok=True)
    os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.', exist_ok=True)


    # Configures the ROOT logger (file + console); mode='w' truncates any
    # previous log. This is process-global, not scoped to this function.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, mode='w'),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger()


    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)


    best_val_loss = float("inf")
    # First validation batch of the first epoch, kept for consistent plots.
    fixed_valid_batch = None

    # Per-log-interval running averages (grow across the whole run).
    iter_loss_accums, iter_mse_accums, iter_consistency_accums = [], [], []
    train_losses, val_losses = [], []


    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, num_samples = 0.0, 0
        iter_count = 0
        iter_loss_accum = iter_mse_accum = iter_consistency_accum = 0.0


        # Only spec/w/z are used; the remaining five tuple entries are ignored.
        for spec, w, z, _, _, _, _, _ in tqdm(trainloader, desc=f"Epoch {epoch} [Train]", leave=False):

            B = spec.size(0)

            # CPU-side smoothing: gaussian filter, then an astropy kernel
            # convolution per spectrum (handles NaNs via convolve).
            flux_smooth = gaussian_filter1d(spec.cpu().detach().numpy(), sigma=5, axis=1)
            kernel = Gaussian1DKernel(stddev=5, x_size=5)
            flux_conv_batch = np.array([convolve(f, kernel) for f in flux_smooth])


            # Rebuild the tensor on the training device; the raw spec is
            # discarded in favor of the smoothed version.
            spec = torch.tensor(flux_conv_batch, device=device, dtype=torch.float32)
            spec, w, z = spec.to(device).float(), w.to(device).float(), z.to(device).float()


            # Rest-frame targets: both flux and its weights are resampled.
            rest_spec = to_rest_frame(spec, z, device=device)[0]
            w_spec = to_rest_frame(w, z, device=device)[0]


            optimizer.zero_grad()
            recon = model(spec, z)
            mse_loss = weighted_mse_loss(recon, rest_spec, w_spec)


            # Consistency term: features of the batch and an augmented copy
            # should agree.
            spec_aug, w_aug, z_aug = desi.augment_spectra(spec, w, z)
            specf_aug = model.extract_features(spec_aug, z_aug)
            specf = model.extract_features(spec, z)
            consistency_loss_value = consistency_loss(specf, specf_aug)


            # Unweighted sum of the two objectives.
            loss = mse_loss + consistency_loss_value

            loss.backward()
            optimizer.step()


            # Sample-weighted accumulation for the epoch average.
            train_loss += loss.item() * B
            num_samples += B
            iter_count += 1


            iter_loss_accum += loss.item()
            iter_mse_accum += mse_loss.item()
            iter_consistency_accum += consistency_loss_value.item()


            # Periodic intermediate logging + diagnostic figures.
            if iter_count % log_interval == 0:
                avg_iter_loss = iter_loss_accum / log_interval
                avg_iter_mse = iter_mse_accum / log_interval
                avg_iter_consistency = iter_consistency_accum / log_interval


                iter_loss_accums.append(avg_iter_loss)
                iter_mse_accums.append(avg_iter_mse)
                iter_consistency_accums.append(avg_iter_consistency)


                logger.info(
                    f"Epoch {epoch} Iter {iter_count}: "
                    f"Avg Total={avg_iter_loss:.6f}, "
                    f"MSE={avg_iter_mse:.6f}, "
                    f"Consistency={avg_iter_consistency:.6f}"
                )


                # Overwrites the same file each time, showing the full run.
                fig_loss, axes = plt.subplots(3, 1, figsize=(10, 12), sharex=True)
                axes[0].plot(iter_loss_accums, color='blue')
                axes[0].set_ylabel('Total Loss'); axes[0].grid(True)
                axes[1].plot(iter_mse_accums, color='green')
                axes[1].set_ylabel('MSE Loss'); axes[1].grid(True)
                axes[2].plot(iter_consistency_accums, color='orange')
                axes[2].set_ylabel('Consistency Loss')
                axes[2].set_xlabel('Iteration'); axes[2].grid(True)
                fig_loss.suptitle('Training Loss Components')
                plt.tight_layout(rect=[0, 0, 1, 0.96])
                fig_loss.savefig(os.path.join(fig_dir, "loss_curve_iter.png"))
                plt.close(fig_loss)


                # Per-sample reconstruction snapshots from the current batch.
                for i in range(min(4, B)):
                    plt.figure(figsize=(12, 3))
                    plt.subplot(2,1,1)
                    plt.plot(rest_spec[i].cpu().numpy(), label="Input", color='blue', linewidth=0.5, alpha=0.7)
                    plt.plot(recon[i].detach().cpu().numpy(), label="Reconstructed", color='red', linewidth=0.5, alpha=0.7)
                    plt.title(f"Epoch {epoch} Iter {iter_count} Sample {i} | z = {z[i].item():.3f}")
                    plt.legend(); plt.grid(True); plt.tight_layout()
                    plt.subplot(2,1,2)
                    plt.plot(w_spec[i].detach().cpu().numpy(), label="w", color='black', linewidth=0.5, alpha=0.7)
                    plt.legend(); plt.grid(True); plt.tight_layout()
                    plt.savefig(os.path.join(fig_dir, f"recon_epoch{epoch}_iter{iter_count}_sample{i}.png"))
                    plt.close()


                # Reset the interval accumulators for the next window.
                iter_loss_accum = iter_mse_accum = iter_consistency_accum = 0.0


        train_loss /= num_samples
        train_losses.append(train_loss)


        # ---- Validation (MSE only; no smoothing, no consistency term) ----
        model.eval()
        val_loss, num_samples = 0.0, 0
        with torch.no_grad():
            for i, (spec, w, z, _, _, _, _, _) in enumerate(tqdm(validloader, desc=f"Epoch {epoch} [Valid]", leave=False)):
                spec, w, z = spec.to(device).float(), w.to(device).float(), z.to(device).float()
                rest_spec = to_rest_frame(spec, z, device=device)[0]
                w_spec = to_rest_frame(w, z, device=device)[0]
                recon = model(spec, z)
                loss = weighted_mse_loss(recon, rest_spec, w_spec)
                val_loss += loss.item() * spec.size(0)
                num_samples += spec.size(0)


                # Cache the first batch once; reused each epoch for plots.
                if fixed_valid_batch is None and i == 0:
                    fixed_valid_batch = (rest_spec.cpu(), recon.cpu(), z.cpu())


        val_loss /= num_samples
        val_losses.append(val_loss)


        logger.info(f"Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")


        # Checkpoint only when validation improves; unwrap DataParallel so
        # the saved keys are loadable by a bare model.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(save_dict, save_path)
            logger.info(f" ✅ New best model saved to {save_path}")


        # Plot reconstructions for the cached validation batch.
        if epoch % plot_interval == 0 and fixed_valid_batch is not None:
            spec_fixed, recon_fixed, z_fixed = fixed_valid_batch
            B_vis = min(4, spec_fixed.size(0))
            fig_recon, axes = plt.subplots(B_vis, 1, figsize=(12, 2 * B_vis))
            if B_vis == 1: axes = [axes]
            for i in range(B_vis):
                ax = axes[i]
                ax.plot(spec_fixed[i].numpy(), label="Input", color='blue', linewidth=0.5, alpha=0.7)
                ax.plot(recon_fixed[i].numpy(), label="Reconstructed", color='red', linewidth=0.5, alpha=0.7)
                ax.set_title(f"z = {z_fixed[i].item():.3f}")
                ax.legend(); ax.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(fig_dir, f"recon_epoch_{epoch:03d}.png"))
            plt.close(fig_recon)


    logger.info("\n🎉 Training completed.")
    logger.info(f"Best Validation Loss: {best_val_loss:.6f}")
|
|
| |
if __name__ == "__main__":


    # Pinned to GPU 3 when CUDA is available; falls back to CPU otherwise.
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")


    # NOTE(review): this rebinds the imported `desi` MODULE name to a DESI()
    # instance; the module itself is unreachable after this line.
    desi = desi.DESI()


    # Catalog of matched sources used to build both splits.
    CSV_PATH = '/home/data/multimodalUniverse/DESIspec_make/matched_sources_by_targetid_inner.csv'


    trainloader = desi.get_data_loader(
        csv_path=CSV_PATH,
        which="train",
        batch_size=64,
        shuffle=True
    )


    # Validation reuses the "test" split; no shuffling for stable plots.
    validloader = desi.get_data_loader(
        csv_path=CSV_PATH,
        which="test",
        batch_size=64,
        shuffle=False
    )


    # patch_size=4 here overrides the class default of 16.
    model = ReconTransformer(
        orig_length=7781,
        target_length=9780,
        patch_size=4,
        d_model=512,
        nhead=8,
        num_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        use_z_cond=True
    )


    model.to(device)


    train_model( desi, model=model,
                trainloader=trainloader,
                validloader=validloader,
                device=device,
                epochs=1000,
                lr=1e-4,
                weight_decay=1e-5,
                save_path="./training_plots_reshift/best_recon_transformer_acc.pth",
                plot_interval=1,
                fig_dir="training_plots_reshift")
|
|