import os
import math
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchinterp1d import interp1d
from tqdm import tqdm
from scipy.ndimage import gaussian_filter1d
from astropy.convolution import convolve, Gaussian1DKernel

from dataset import desi


class PatchEmbed1D(nn.Module):
    """Split a 1-D spectrum into fixed-size patches and linearly embed each patch.

    Input of shape (B, L) is zero-padded on the right so L becomes an exact
    multiple of ``patch_size``, reshaped to (B, N, patch_size) and projected to
    (B, N, d_model).
    """

    def __init__(self, patch_size=16, d_model=512):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(patch_size, d_model)

    def forward(self, x):
        """Return (tokens, padded_length) for a batch of spectra.

        Args:
            x: (B, L) float tensor.

        Returns:
            tokens: (B, ceil(L / patch_size), d_model) patch embeddings.
            padded_length: L plus the amount of right zero-padding applied.
        """
        B, L = x.shape
        # Right-pad with zeros so the sequence divides evenly into patches.
        pad_len = (math.ceil(L / self.patch_size) * self.patch_size) - L
        if pad_len > 0:
            x = F.pad(x, (0, pad_len))
        x = x.view(B, -1, self.patch_size)
        tokens = self.proj(x)
        return tokens, L + pad_len


class ReconTransformer(nn.Module):
    """Transformer encoder that reconstructs a rest-frame spectrum from an
    observed-frame spectrum, optionally conditioned on redshift ``z``.

    The observed spectrum (length ``orig_length``) is patch-embedded, encoded,
    resampled along the patch axis to match the target patch count, and decoded
    patch-wise back to a spectrum of length ``target_length``.
    """

    def __init__(self, orig_length=7781, target_length=9780, patch_size=16,
                 d_model=512, nhead=8, num_layers=6, dim_feedforward=2048,
                 dropout=0.1, use_z_cond=True):
        super().__init__()
        self.orig_length = orig_length
        self.target_length = target_length
        self.patch_size = patch_size
        self.d_model = d_model
        self.use_z_cond = use_z_cond

        self.patch_embed = PatchEmbed1D(patch_size=patch_size, d_model=d_model)

        # Learned positional embedding, one vector per input patch.
        max_patches = math.ceil(orig_length / patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, max_patches, d_model) * 0.02)

        # Small MLP that embeds the scalar redshift; added to every token.
        if use_z_cond:
            self.z_proj = nn.Sequential(
                nn.Linear(1, d_model),
                nn.GELU(),
                nn.Linear(d_model, d_model),
            )
        else:
            self.z_proj = None

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.recon_head = nn.Linear(d_model, patch_size)
        self.norm = nn.LayerNorm(d_model)
        self._init_weights()

        # Number of output patches needed to cover target_length.
        self.out_patches = math.ceil(self.target_length / patch_size)

    def _init_weights(self):
        """Xavier-initialize the reconstruction head."""
        nn.init.xavier_uniform_(self.recon_head.weight)
        if self.recon_head.bias is not None:
            nn.init.zeros_(self.recon_head.bias)

    def _encode(self, spec, z):
        """Shared embed + condition + encode path for forward/extract_features."""
        tokens, _ = self.patch_embed(spec)
        N_in = tokens.shape[1]
        tokens = tokens + self.pos_embed[:, :N_in, :]
        if self.use_z_cond and z is not None:
            if z.dim() == 1:
                z = z.unsqueeze(-1)
            z_emb = self.z_proj(z)
            # Broadcast the redshift embedding over all tokens.
            tokens = tokens + z_emb.unsqueeze(1)
        tokens = self.norm(tokens)
        return self.encoder(tokens), N_in

    def forward(self, spec, z=None):
        """Reconstruct the rest-frame spectrum.

        Args:
            spec: (B, orig_length) observed-frame spectrum.
            z:    optional (B,) or (B, 1) redshift used for conditioning.

        Returns:
            (B, target_length) reconstructed spectrum.
        """
        B, _ = spec.shape
        enc_out, N_in = self._encode(spec, z)

        # Resample the token sequence so the decoded patches cover target_length.
        N_out = self.out_patches
        if N_out != N_in:
            enc_out = F.interpolate(
                enc_out.transpose(1, 2), size=N_out,
                mode="linear", align_corners=False,
            ).transpose(1, 2)

        patches = self.recon_head(enc_out)
        recon = patches.reshape(B, N_out * self.patch_size)
        # Drop the tail padding introduced by the ceil() patch count.
        return recon[:, :self.target_length]

    def extract_features(self, spec, z=None, pool="mean"):
        """Return a pooled encoder representation of ``spec``.

        Args:
            spec: (B, orig_length) observed-frame spectrum.
            z:    optional redshift conditioning, as in ``forward``.
            pool: "mean" (token average), "cls" (first token), or
                  "flatten" (all tokens concatenated).

        Raises:
            ValueError: for an unrecognized ``pool`` value.
        """
        B, _ = spec.shape
        enc_out, _ = self._encode(spec, z)
        if pool == "mean":
            feat = enc_out.mean(dim=1)
        elif pool == "cls":
            feat = enc_out[:, 0, :]
        elif pool == "flatten":
            feat = enc_out.reshape(B, -1)
        else:
            # BUGFIX: the original raised unconditionally after the flatten
            # branch; the raise belongs in a final else clause.
            raise ValueError("Unknown pooling method")
        return feat


def to_rest_frame(spec_obs, z, wave_obs_min=3600.0, wave_obs_max=9824.0,
                  n_obs=7781, n_rest=9780, device="cuda", z_max=0.8):
    """Resample observed-frame spectra onto a common rest-frame wavelength grid.

    The rest-frame grid spans [wave_obs_min / (1 + z_max), wave_obs_max]; for
    each object the grid is redshifted back to the observed frame and the
    spectrum is linearly interpolated there. Rest wavelengths that fall outside
    the observed coverage are zeroed.

    Args:
        spec_obs: (B, n_obs) observed spectra on a linear wavelength grid.
        z:        (B,) redshifts.
        z_max:    maximum redshift assumed when building the rest grid
                  (default 0.8, matching the original hard-coded value).

    Returns:
        (spec_rest, wave_rest): (B, n_rest) resampled spectra and the shared
        (n_rest,) rest-frame wavelength grid.
    """
    B = spec_obs.size(0)
    wave_obs = torch.linspace(wave_obs_min, wave_obs_max, n_obs,
                              dtype=torch.float32, device=device)
    wave_obs_batch = wave_obs.unsqueeze(0).expand(B, -1)

    lmbda_min = wave_obs_min / (1.0 + z_max)
    lmbda_max = wave_obs_max / 1.0
    wave_rest = torch.linspace(lmbda_min, lmbda_max, n_rest,
                               dtype=torch.float32, device=device)
    wave_rest_batch = wave_rest.unsqueeze(0).expand(B, -1)

    # Observed wavelength each rest-frame pixel corresponds to.
    wave_redshifted = wave_rest_batch * (1 + z.unsqueeze(-1))
    spec_rest = interp1d(wave_obs_batch, spec_obs, wave_redshifted)

    # Zero out rest pixels that map outside the observed coverage.
    valid = (wave_redshifted >= wave_obs_min) & (wave_redshifted <= wave_obs_max)
    spec_rest[~valid] = 0.0
    return spec_rest, wave_rest


def weighted_mse_loss(recon, target, weight):
    """Weighted mean-squared error; falls back to plain MSE if all weights are 0."""
    diff = (recon - target) ** 2
    weighted = diff * weight
    denom = weight.sum()
    if denom == 0:
        return diff.mean()
    return weighted.sum() / denom


def consistency_loss(sf, sf_aug, individual=False):
    """Similarity loss between features of a spectrum and its augmentation.

    The per-sample squared distance (scaled by 1/0.5^2 and normalized by the
    feature dimension) is squashed through sigmoid(x) - 0.5 so the loss
    saturates for very dissimilar pairs.

    Args:
        individual: if True, return the per-sample (distance, loss) tensors
                    instead of the summed loss.
    """
    batch_size, s_size = sf.shape
    x = torch.sum((sf_aug - sf) ** 2 / (0.5) ** 2, dim=1) / s_size
    sim_loss = torch.sigmoid(x) - 0.5
    if individual:
        return x, sim_loss
    return sim_loss.sum()


def train_model(desi, model, trainloader, validloader, device,
                epochs=20, lr=1e-4, weight_decay=1e-5,
                save_path="best_recon_transformer.pth",
                plot_interval=1, fig_dir="figures",
                log_interval=5000, log_file="training.log"):
    """Train ``model`` with weighted reconstruction MSE + augmentation
    consistency loss, validating each epoch and checkpointing the best model.

    Side effects: writes ``log_file``, saves loss/reconstruction figures into
    ``fig_dir`` and the best checkpoint to ``save_path``.
    """
    # Configure logging: write to file (overwrite) and echo to console.
    os.makedirs(fig_dir, exist_ok=True)
    os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.',
                exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, mode='w'),
            logging.StreamHandler(),
        ],
    )
    logger = logging.getLogger()

    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_loss = float("inf")
    fixed_valid_batch = None
    iter_loss_accums, iter_mse_accums, iter_consistency_accums = [], [], []
    train_losses, val_losses = [], []

    # Loop-invariant smoothing kernel (was rebuilt per batch in the original).
    kernel = Gaussian1DKernel(stddev=5, x_size=5)

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, num_samples = 0.0, 0
        iter_count = 0
        iter_loss_accum = iter_mse_accum = iter_consistency_accum = 0.0

        for spec, w, z, _, _, _, _, _ in tqdm(trainloader,
                                              desc=f"Epoch {epoch} [Train]",
                                              leave=False):
            B = spec.size(0)

            # Preprocessing: Gaussian smoothing followed by kernel convolution.
            flux_smooth = gaussian_filter1d(spec.cpu().detach().numpy(),
                                            sigma=5, axis=1)
            flux_conv_batch = np.array([convolve(f, kernel) for f in flux_smooth])
            spec = torch.tensor(flux_conv_batch, device=device, dtype=torch.float32)
            spec, w, z = (spec.to(device).float(), w.to(device).float(),
                          z.to(device).float())

            # Targets: spectrum and weights resampled to the rest frame.
            rest_spec = to_rest_frame(spec, z, device=device)[0]
            w_spec = to_rest_frame(w, z, device=device)[0]

            optimizer.zero_grad()
            recon = model(spec, z)
            mse_loss = weighted_mse_loss(recon, rest_spec, w_spec)

            # Consistency between features of original and augmented spectra.
            spec_aug, w_aug, z_aug = desi.augment_spectra(spec, w, z)
            specf_aug = model.extract_features(spec_aug, z_aug)
            specf = model.extract_features(spec, z)
            consistency_loss_value = consistency_loss(specf, specf_aug)

            loss = mse_loss + consistency_loss_value
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * B
            num_samples += B
            iter_count += 1
            iter_loss_accum += loss.item()
            iter_mse_accum += mse_loss.item()
            iter_consistency_accum += consistency_loss_value.item()

            if iter_count % log_interval == 0:
                avg_iter_loss = iter_loss_accum / log_interval
                avg_iter_mse = iter_mse_accum / log_interval
                avg_iter_consistency = iter_consistency_accum / log_interval
                iter_loss_accums.append(avg_iter_loss)
                iter_mse_accums.append(avg_iter_mse)
                iter_consistency_accums.append(avg_iter_consistency)

                logger.info(
                    f"Epoch {epoch} Iter {iter_count}: "
                    f"Avg Total={avg_iter_loss:.6f}, "
                    f"MSE={avg_iter_mse:.6f}, "
                    f"Consistency={avg_iter_consistency:.6f}"
                )

                # Plot loss curves
                fig_loss, axes = plt.subplots(3, 1, figsize=(10, 12), sharex=True)
                axes[0].plot(iter_loss_accums, color='blue')
                axes[0].set_ylabel('Total Loss'); axes[0].grid(True)
                axes[1].plot(iter_mse_accums, color='green')
                axes[1].set_ylabel('MSE Loss'); axes[1].grid(True)
                axes[2].plot(iter_consistency_accums, color='orange')
                axes[2].set_ylabel('Consistency Loss')
                axes[2].set_xlabel('Iteration'); axes[2].grid(True)
                fig_loss.suptitle('Training Loss Components')
                plt.tight_layout(rect=[0, 0, 1, 0.96])
                fig_loss.savefig(os.path.join(fig_dir, "loss_curve_iter.png"))
                plt.close(fig_loss)

                # Plot reconstructions
                for i in range(min(4, B)):
                    plt.figure(figsize=(12, 3))
                    plt.subplot(2, 1, 1)
                    plt.plot(rest_spec[i].cpu().numpy(), label="Input",
                             color='blue', linewidth=0.5, alpha=0.7)
                    plt.plot(recon[i].detach().cpu().numpy(),
                             label="Reconstructed", color='red',
                             linewidth=0.5, alpha=0.7)
                    plt.title(f"Epoch {epoch} Iter {iter_count} Sample {i} "
                              f"| z = {z[i].item():.3f}")
                    plt.legend(); plt.grid(True); plt.tight_layout()
                    plt.subplot(2, 1, 2)
                    plt.plot(w_spec[i].detach().cpu().numpy(), label="w",
                             color='black', linewidth=0.5, alpha=0.7)
                    plt.legend(); plt.grid(True); plt.tight_layout()
                    plt.savefig(os.path.join(
                        fig_dir,
                        f"recon_epoch{epoch}_iter{iter_count}_sample{i}.png"))
                    plt.close()

                iter_loss_accum = iter_mse_accum = iter_consistency_accum = 0.0

        train_loss /= num_samples
        train_losses.append(train_loss)

        # Validation (reconstruction loss only, no augmentation consistency).
        model.eval()
        val_loss, num_samples = 0.0, 0
        with torch.no_grad():
            for i, (spec, w, z, _, _, _, _, _) in enumerate(
                    tqdm(validloader, desc=f"Epoch {epoch} [Valid]", leave=False)):
                spec, w, z = (spec.to(device).float(), w.to(device).float(),
                              z.to(device).float())
                rest_spec = to_rest_frame(spec, z, device=device)[0]
                w_spec = to_rest_frame(w, z, device=device)[0]
                recon = model(spec, z)
                loss = weighted_mse_loss(recon, rest_spec, w_spec)
                val_loss += loss.item() * spec.size(0)
                num_samples += spec.size(0)
                # Keep one fixed batch so epoch-to-epoch plots are comparable.
                if fixed_valid_batch is None and i == 0:
                    fixed_valid_batch = (rest_spec.cpu(), recon.cpu(), z.cpu())

        val_loss /= num_samples
        val_losses.append(val_loss)

        logger.info(f"Epoch {epoch}: Train Loss = {train_loss:.6f}, "
                    f"Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Unwrap DataParallel so the checkpoint loads into a bare model.
            save_dict = (model.module.state_dict()
                         if isinstance(model, nn.DataParallel)
                         else model.state_dict())
            torch.save(save_dict, save_path)
            logger.info(f" ✅ New best model saved to {save_path}")

        # Plot validation reconstructions periodically
        if epoch % plot_interval == 0 and fixed_valid_batch is not None:
            spec_fixed, recon_fixed, z_fixed = fixed_valid_batch
            B_vis = min(4, spec_fixed.size(0))
            fig_recon, axes = plt.subplots(B_vis, 1, figsize=(12, 2 * B_vis))
            if B_vis == 1:
                axes = [axes]
            for i in range(B_vis):
                ax = axes[i]
                ax.plot(spec_fixed[i].numpy(), label="Input", color='blue',
                        linewidth=0.5, alpha=0.7)
                ax.plot(recon_fixed[i].numpy(), label="Reconstructed",
                        color='red', linewidth=0.5, alpha=0.7)
                ax.set_title(f"z = {z_fixed[i].item():.3f}")
                ax.legend(); ax.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(fig_dir, f"recon_epoch_{epoch:03d}.png"))
            plt.close(fig_recon)

    logger.info("\n🎉 Training completed.")
    logger.info(f"Best Validation Loss: {best_val_loss:.6f}")


if __name__ == "__main__":
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    desi = desi.DESI()
    CSV_PATH = '/home/data/multimodalUniverse/DESIspec_make/matched_sources_by_targetid_inner.csv'
    trainloader = desi.get_data_loader(
        csv_path=CSV_PATH, which="train", batch_size=64, shuffle=True
    )
    validloader = desi.get_data_loader(
        csv_path=CSV_PATH, which="test", batch_size=64, shuffle=False
    )
    model = ReconTransformer(
        orig_length=7781, target_length=9780, patch_size=4, d_model=512,
        nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1,
        use_z_cond=True
    )
    model.to(device)
    train_model(
        desi, model=model, trainloader=trainloader, validloader=validloader,
        device=device, epochs=1000, lr=1e-4, weight_decay=1e-5,
        save_path="./training_plots_reshift/best_recon_transformer_acc.pth",
        plot_interval=1, fig_dir="training_plots_reshift")