# AstroM3 / ExampleCode / example1 / model / SpecEncoder.py
# (originally uploaded by lvjiameng, "Upload 21 files", commit d24fe95)
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinterp1d import interp1d
import torch.optim as optim
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from dataset import desi
from scipy.ndimage import gaussian_filter1d
from astropy.convolution import convolve, Gaussian1DKernel
import numpy as np
import logging
class PatchEmbed1D(nn.Module):
    """Split a 1-D signal into fixed-size patches and linearly project each to d_model.

    The input is right-padded with zeros so its length becomes a multiple of
    ``patch_size`` before patching.
    """

    def __init__(self, patch_size=16, d_model=512):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(patch_size, d_model)

    def forward(self, x):
        """Return ``(tokens, padded_length)`` for an input of shape (B, L).

        tokens has shape (B, ceil(L / patch_size), d_model).
        """
        batch, length = x.shape
        remainder = length % self.patch_size
        pad_len = 0 if remainder == 0 else self.patch_size - remainder
        if pad_len:
            x = F.pad(x, (0, pad_len))
        patches = x.view(batch, -1, self.patch_size)
        return self.proj(patches), length + pad_len
class ReconTransformer(nn.Module):
    """Transformer encoder that maps an observed-frame spectrum (length
    ``orig_length``) to a rest-frame reconstruction (length ``target_length``).

    The spectrum is patch-embedded, optionally conditioned on redshift ``z``
    via an additive embedding, passed through a Transformer encoder, and the
    encoded tokens are linearly decoded back into patches. When the number of
    output patches differs from the input, the token sequence is linearly
    interpolated along the sequence axis.
    """

    def __init__(self,
                 orig_length=7781,
                 target_length=9780,
                 patch_size=16,
                 d_model=512,
                 nhead=8,
                 num_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 use_z_cond=True):
        super().__init__()
        self.orig_length = orig_length
        self.target_length = target_length
        self.patch_size = patch_size
        self.d_model = d_model
        self.use_z_cond = use_z_cond
        self.patch_embed = PatchEmbed1D(patch_size=patch_size, d_model=d_model)
        # Learned positional embedding, sized for the longest expected input.
        max_patches = math.ceil(orig_length / patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, max_patches, d_model) * 0.02)
        if use_z_cond:
            # Projects a scalar redshift into d_model for additive conditioning.
            self.z_proj = nn.Sequential(
                nn.Linear(1, d_model),
                nn.GELU(),
                nn.Linear(d_model, d_model)
            )
        else:
            self.z_proj = None
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.recon_head = nn.Linear(d_model, patch_size)
        self.norm = nn.LayerNorm(d_model)
        self._init_weights()
        self.out_patches = math.ceil(self.target_length / patch_size)

    def _init_weights(self):
        """Xavier-initialize the reconstruction head."""
        nn.init.xavier_uniform_(self.recon_head.weight)
        if self.recon_head.bias is not None:
            nn.init.zeros_(self.recon_head.bias)

    def _encode(self, spec, z):
        """Shared pipeline: patch-embed, add positions, optional z-conditioning,
        LayerNorm, then the Transformer encoder.

        Returns (encoded tokens of shape (B, N_in, d_model), N_in).
        """
        tokens, _ = self.patch_embed(spec)
        n_in = tokens.shape[1]
        tokens = tokens + self.pos_embed[:, :n_in, :]
        if self.use_z_cond and z is not None:
            if z.dim() == 1:
                z = z.unsqueeze(-1)
            # Broadcast the redshift embedding over all tokens.
            tokens = tokens + self.z_proj(z).unsqueeze(1)
        tokens = self.norm(tokens)
        return self.encoder(tokens), n_in

    def forward(self, spec, z=None):
        """Reconstruct a rest-frame spectrum.

        Args:
            spec: (B, L) observed-frame flux.
            z: optional (B,) or (B, 1) redshift used for conditioning.

        Returns:
            (B, target_length) reconstruction.
        """
        B = spec.size(0)
        enc_out, n_in = self._encode(spec, z)
        n_out = self.out_patches
        if n_out != n_in:
            # Resample the token sequence to the output patch count.
            enc_out = F.interpolate(
                enc_out.transpose(1, 2),
                size=n_out,
                mode="linear",
                align_corners=False
            ).transpose(1, 2)
        patches = self.recon_head(enc_out)
        recon = patches.reshape(B, n_out * self.patch_size)
        # Drop the padding introduced by rounding up to whole patches.
        return recon[:, :self.target_length]

    def extract_features(self, spec, z=None, pool="mean"):
        """Return pooled encoder features for ``spec``.

        Args:
            pool: "mean" (token average), "cls" (first token), or
                "flatten" (all tokens concatenated).

        Raises:
            ValueError: if ``pool`` is not one of the supported methods.
        """
        B = spec.size(0)
        enc_out, _ = self._encode(spec, z)
        if pool == "mean":
            return enc_out.mean(dim=1)
        if pool == "cls":
            return enc_out[:, 0, :]
        if pool == "flatten":
            return enc_out.reshape(B, -1)
        # Original code raised unconditionally after the flatten branch,
        # breaking valid pooling modes; now only unknown modes raise.
        raise ValueError("Unknown pooling method")
def to_rest_frame(spec_obs, z, wave_obs_min=3600.0, wave_obs_max=9824.0, n_obs=7781, n_rest=9780, device="cuda", z_max=0.8):
    """Interpolate observed-frame spectra onto a shared rest-frame wavelength grid.

    Args:
        spec_obs: (B, n_obs) observed-frame flux, sampled on a uniform grid
            from ``wave_obs_min`` to ``wave_obs_max`` Angstroms.
        z: (B,) redshifts.
        wave_obs_min, wave_obs_max: observed wavelength coverage.
        n_obs: number of observed samples (must match spec_obs.shape[1]).
        n_rest: number of samples on the rest-frame grid.
        device: device for the wavelength grids.
        z_max: largest redshift the rest-frame grid must accommodate
            (previously a hard-coded 0.8).

    Returns:
        (spec_rest, wave_rest): (B, n_rest) rest-frame flux with out-of-coverage
        samples zeroed, and the (n_rest,) rest-frame wavelength grid.
    """
    B = spec_obs.size(0)
    wave_obs = torch.linspace(wave_obs_min, wave_obs_max, n_obs,
                              dtype=torch.float32, device=device)
    wave_obs_batch = wave_obs.unsqueeze(0).expand(B, -1)
    # Rest-frame grid covers wavelengths reachable for any z in [0, z_max]:
    # bluest sample at z = z_max, reddest at z = 0.
    lmbda_min = wave_obs_min / (1.0 + z_max)
    lmbda_max = wave_obs_max
    wave_rest = torch.linspace(lmbda_min, lmbda_max, n_rest,
                               dtype=torch.float32, device=device)
    wave_rest_batch = wave_rest.unsqueeze(0).expand(B, -1)
    # Where each rest-frame sample lands in the observed frame.
    wave_redshifted = wave_rest_batch * (1 + z.unsqueeze(-1))
    spec_rest = interp1d(wave_obs_batch, spec_obs, wave_redshifted)
    # Zero out samples that fall outside the observed coverage (extrapolation).
    valid = (wave_redshifted >= wave_obs_min) & (wave_redshifted <= wave_obs_max)
    spec_rest[~valid] = 0.0
    return spec_rest, wave_rest
def weighted_mse_loss(recon, target, weight):
    """Weighted mean squared error between ``recon`` and ``target``.

    Falls back to the unweighted mean when all weights sum to zero, so a
    fully-masked batch still yields a finite loss.
    """
    sq_err = (recon - target).pow(2)
    total_weight = weight.sum()
    if total_weight == 0:
        return sq_err.mean()
    return (sq_err * weight).sum() / total_weight
def consistency_loss(sf, sf_aug, individual=False):
    """Sigmoid-squashed consistency penalty between two feature batches.

    Computes a per-sample scaled mean squared distance x between ``sf`` and
    ``sf_aug`` (temperature 0.5), maps it through sigmoid(x) - 0.5 so identical
    features score 0, and sums over the batch. With ``individual=True`` the
    per-sample distances and losses are returned instead.
    """
    _, feat_dim = sf.shape
    scaled_sq = (sf_aug - sf).pow(2) / 0.25  # divide by (0.5)**2
    x = scaled_sq.sum(dim=1) / feat_dim
    sim_loss = torch.sigmoid(x) - 0.5
    return (x, sim_loss) if individual else sim_loss.sum()
def train_model(desi, model, trainloader, validloader, device, epochs=20, lr=1e-4, weight_decay=1e-5,
                save_path="best_recon_transformer.pth", plot_interval=1, fig_dir="figures",
                log_interval=5000, log_file="training.log"):
    """Train the reconstruction model with weighted-MSE plus a feature-consistency term.

    Per training batch: smooth the spectra (Gaussian filter + astropy kernel
    convolution on CPU), resample spectra and weights to the rest frame, compute
    a weighted reconstruction MSE, and add a consistency loss between features of
    the original and augmented (via ``desi.augment_spectra``) spectra. Logs running
    averages every ``log_interval`` iterations, saves loss/reconstruction figures
    into ``fig_dir``, validates each epoch, and checkpoints the best model by
    validation loss to ``save_path``.

    Args:
        desi: dataset helper providing ``augment_spectra(spec, w, z)``.
        model: network with ``forward(spec, z)`` and ``extract_features(spec, z)``.
        trainloader / validloader: yield 8-tuples whose first three entries are
            (spec, weight, redshift); the remaining five are unused here.
        device: torch device for training.
        epochs, lr, weight_decay: optimizer/loop hyperparameters (AdamW).
        save_path: checkpoint path for the best validation model.
        plot_interval: epochs between validation reconstruction plots.
        fig_dir: output directory for figures.
        log_interval: iterations between running-average log/plot dumps.
        log_file: path for the logging file handler (overwritten each run).
    """
    # === Configure logging (file + console) ===
    os.makedirs(fig_dir, exist_ok=True)
    os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.', exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, mode='w'),  # write to file (overwrite mode)
            logging.StreamHandler()  # also echo to the console
        ]
    )
    logger = logging.getLogger()
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_loss = float("inf")
    # First validation batch is frozen for comparable plots across epochs.
    fixed_valid_batch = None
    iter_loss_accums, iter_mse_accums, iter_consistency_accums = [], [], []
    train_losses, val_losses = [], []
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, num_samples = 0.0, 0
        iter_count = 0
        iter_loss_accum = iter_mse_accum = iter_consistency_accum = 0.0
        for spec, w, z, _, _, _, _, _ in tqdm(trainloader, desc=f"Epoch {epoch} [Train]", leave=False):
            B = spec.size(0)
            # Preprocessing: smoothing
            # NOTE(review): this round-trips each batch through CPU/NumPy —
            # confirm the cost is acceptable vs. precomputing in the dataset.
            flux_smooth = gaussian_filter1d(spec.cpu().detach().numpy(), sigma=5, axis=1)
            kernel = Gaussian1DKernel(stddev=5, x_size=5)
            flux_conv_batch = np.array([convolve(f, kernel) for f in flux_smooth])
            spec = torch.tensor(flux_conv_batch, device=device, dtype=torch.float32)
            spec, w, z = spec.to(device).float(), w.to(device).float(), z.to(device).float()
            # Resample spectrum and its per-pixel weights onto the rest-frame grid.
            rest_spec = to_rest_frame(spec, z, device=device)[0]
            w_spec = to_rest_frame(w, z, device=device)[0]
            optimizer.zero_grad()
            recon = model(spec, z)
            mse_loss = weighted_mse_loss(recon, rest_spec, w_spec)
            # Consistency between features of the original and augmented views.
            spec_aug, w_aug, z_aug = desi.augment_spectra(spec, w, z)
            specf_aug = model.extract_features(spec_aug, z_aug)
            specf = model.extract_features(spec, z)
            consistency_loss_value = consistency_loss(specf, specf_aug)
            loss = mse_loss + consistency_loss_value
            # loss = mse_loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * B
            num_samples += B
            iter_count += 1
            iter_loss_accum += loss.item()
            iter_mse_accum += mse_loss.item()
            iter_consistency_accum += consistency_loss_value.item()
            if iter_count % log_interval == 0:
                avg_iter_loss = iter_loss_accum / log_interval
                avg_iter_mse = iter_mse_accum / log_interval
                avg_iter_consistency = iter_consistency_accum / log_interval
                iter_loss_accums.append(avg_iter_loss)
                iter_mse_accums.append(avg_iter_mse)
                iter_consistency_accums.append(avg_iter_consistency)
                # Log running averages (replaces print).
                logger.info(
                    f"Epoch {epoch} Iter {iter_count}: "
                    f"Avg Total={avg_iter_loss:.6f}, "
                    f"MSE={avg_iter_mse:.6f}, "
                    f"Consistency={avg_iter_consistency:.6f}"
                )
                # Plot loss curves
                fig_loss, axes = plt.subplots(3, 1, figsize=(10, 12), sharex=True)
                axes[0].plot(iter_loss_accums, color='blue')
                axes[0].set_ylabel('Total Loss'); axes[0].grid(True)
                axes[1].plot(iter_mse_accums, color='green')
                axes[1].set_ylabel('MSE Loss'); axes[1].grid(True)
                axes[2].plot(iter_consistency_accums, color='orange')
                axes[2].set_ylabel('Consistency Loss')
                axes[2].set_xlabel('Iteration'); axes[2].grid(True)
                fig_loss.suptitle('Training Loss Components')
                plt.tight_layout(rect=[0, 0, 1, 0.96])
                fig_loss.savefig(os.path.join(fig_dir, "loss_curve_iter.png"))
                plt.close(fig_loss)
                # Plot reconstructions
                for i in range(min(4, B)):
                    plt.figure(figsize=(12, 3))
                    plt.subplot(2,1,1)
                    plt.plot(rest_spec[i].cpu().numpy(), label="Input", color='blue', linewidth=0.5, alpha=0.7)
                    plt.plot(recon[i].detach().cpu().numpy(), label="Reconstructed", color='red', linewidth=0.5, alpha=0.7)
                    plt.title(f"Epoch {epoch} Iter {iter_count} Sample {i} | z = {z[i].item():.3f}")
                    plt.legend(); plt.grid(True); plt.tight_layout()
                    plt.subplot(2,1,2)
                    plt.plot(w_spec[i].detach().cpu().numpy(), label="w", color='black', linewidth=0.5, alpha=0.7)
                    plt.legend(); plt.grid(True); plt.tight_layout()
                    plt.savefig(os.path.join(fig_dir, f"recon_epoch{epoch}_iter{iter_count}_sample{i}.png"))
                    plt.close()
                # Reset the running accumulators for the next window.
                iter_loss_accum = iter_mse_accum = iter_consistency_accum = 0.0
        train_loss /= num_samples
        train_losses.append(train_loss)
        # Validation
        model.eval()
        val_loss, num_samples = 0.0, 0
        with torch.no_grad():
            for i, (spec, w, z, _, _, _, _, _) in enumerate(tqdm(validloader, desc=f"Epoch {epoch} [Valid]", leave=False)):
                spec, w, z = spec.to(device).float(), w.to(device).float(), z.to(device).float()
                rest_spec = to_rest_frame(spec, z, device=device)[0]
                w_spec = to_rest_frame(w, z, device=device)[0]
                recon = model(spec, z)
                # Validation uses only the reconstruction MSE (no consistency term).
                loss = weighted_mse_loss(recon, rest_spec, w_spec)
                val_loss += loss.item() * spec.size(0)
                num_samples += spec.size(0)
                if fixed_valid_batch is None and i == 0:
                    fixed_valid_batch = (rest_spec.cpu(), recon.cpu(), z.cpu())
        val_loss /= num_samples
        val_losses.append(val_loss)
        # Log epoch-level summary (replaces print).
        logger.info(f"Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Unwrap DataParallel so the checkpoint loads into a bare model.
            save_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(save_dict, save_path)
            logger.info(f" ✅ New best model saved to {save_path}")
        # Plot validation reconstructions periodically
        if epoch % plot_interval == 0 and fixed_valid_batch is not None:
            spec_fixed, recon_fixed, z_fixed = fixed_valid_batch
            B_vis = min(4, spec_fixed.size(0))
            fig_recon, axes = plt.subplots(B_vis, 1, figsize=(12, 2 * B_vis))
            if B_vis == 1: axes = [axes]
            for i in range(B_vis):
                ax = axes[i]
                ax.plot(spec_fixed[i].numpy(), label="Input", color='blue', linewidth=0.5, alpha=0.7)
                ax.plot(recon_fixed[i].numpy(), label="Reconstructed", color='red', linewidth=0.5, alpha=0.7)
                ax.set_title(f"z = {z_fixed[i].item():.3f}")
                ax.legend(); ax.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(fig_dir, f"recon_epoch_{epoch:03d}.png"))
            plt.close(fig_recon)
    logger.info("\n🎉 Training completed.")
    logger.info(f"Best Validation Loss: {best_val_loss:.6f}")
if __name__ == "__main__":
    # Pin training to a specific GPU when available; fall back to CPU.
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    # NOTE(review): this rebinds the imported `desi` module name to a DESI
    # instance — confirm nothing below still needs the module itself.
    desi = desi.DESI()
    # Hard-coded path to the matched-sources catalogue used to build loaders.
    CSV_PATH = '/home/data/multimodalUniverse/DESIspec_make/matched_sources_by_targetid_inner.csv'
    trainloader = desi.get_data_loader(
        csv_path=CSV_PATH,
        which="train",
        batch_size=64,
        shuffle=True
    )
    validloader = desi.get_data_loader(
        csv_path=CSV_PATH,
        which="test",
        batch_size=64,
        shuffle=False
    )
    # patch_size=4 here overrides the class default of 16.
    model = ReconTransformer(
        orig_length=7781,
        target_length=9780,
        patch_size=4,
        d_model=512,
        nhead=8,
        num_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        use_z_cond=True
    )
    model.to(device)
    train_model( desi, model=model,
                trainloader=trainloader,
                validloader=validloader,
                device=device,
                epochs=1000,
                lr=1e-4,
                weight_decay=1e-5,
                save_path="./training_plots_reshift/best_recon_transformer_acc.pth",
                plot_interval=1,
                fig_dir="training_plots_reshift")