"""
src/diffusion.py
V-1 (Base): TransformerDenoiser
V-2 (NOVEL#1): Entropy-adaptive mass gate + gate_confidence export
V-3 (NOVEL#4): Spectral noise augmentation during training
"""
import math
import os
import glob
import sys
import random
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
sys.path.insert(0, os.path.dirname(__file__))
from baseline import Encoder
from preprocessing import (
load_labeled_spectra, preprocess_spectrum, encode_peptide,
VOCAB, CHAR_TO_IDX,
)
# ── Constants ──────────────────────────────────────────────────────────────────
VOCAB_SIZE = 24 # PAD=0, SOS=1, EOS=2, AAs=3..22, MASK=23
MASK_TOK = 23 # absorbing state: never predicted, only used in q_sample
SEQ_LEN = 32 # 30 residues + SOS + EOS
T_STEPS = 200
BETA_START = 1e-3 # linear-schedule bound, kept for reference; the cosine schedule below is used
BETA_END = 2e-2 # linear-schedule bound, kept for reference; the cosine schedule below is used
PROTON_MASS = 1.007276
WATER_MASS = 18.010565
N_PEAKS = 200 # top-K peaks kept per spectrum
MAX_MZ = 2000.0 # m/z normalisation ceiling
BY_SIGMA = 0.1 # Da, Gaussian half-width for b/y ion proximity bias
# Entropy-adaptive gate hyperparameters (NOVEL #1)
GATE_BASE_TOL = 0.02 # Da, tight tolerance when model is confident
GATE_ALPHA = 2.0 # scales how much entropy relaxes the gate
# Spectral noise augmentation hyperparameters (NOVEL #4)
AUG_PROB = 0.4 # probability of augmenting a training spectrum
AUG_NOISE_STD = 0.05 # std of Gaussian noise added to the max-normalised intensities
# Mass consistency loss weight (Option A fix)
# Uses relative L1 (|pred-true|/true), so values are in [0,1] range.
# 0.1 weight means mass term contributes ~10% as much as CE when error is ~10%.
MASS_LOSS_WEIGHT = 0.1
# VOCAB = "ACDEFGHIKLMNPQRSTVWY" → token[i+3] = VOCAB[i]
_MONO_MASSES = [
71.03711, # A token 3
103.00919, # C token 4
115.02694, # D token 5
129.04259, # E token 6
147.06841, # F token 7
57.02146, # G token 8
137.05891, # H token 9
113.08406, # I token 10
128.09496, # K token 11
113.08406, # L token 12
131.04049, # M token 13
114.04293, # N token 14
97.05276, # P token 15
128.05858, # Q token 16
156.10111, # R token 17
87.03203, # S token 18
101.04768, # T token 19
99.06841, # V token 20
186.07931, # W token 21
163.06333, # Y token 22
]
RESIDUE_MASS = torch.zeros(VOCAB_SIZE)
for _i, _m in enumerate(_MONO_MASSES):
RESIDUE_MASS[_i + 3] = _m
IDX_TO_CHAR = {i + 3: c for i, c in enumerate(VOCAB)}
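# Worked example of the lookup (illustrative comments only, nothing is executed here):
# a peptide's neutral monoisotopic mass is the sum of its residue masses plus one water,
# e.g. for the hypothetical tripeptide "GAS":
#   RESIDUE_MASS[8] (G) + RESIDUE_MASS[3] (A) + RESIDUE_MASS[18] (S) + WATER_MASS
#   = 57.02146 + 71.03711 + 87.03203 + 18.010565 ≈ 233.10117 Da
# which is exactly the quantity that the mass-consistency loss and the mass gate compare
# against the neutral precursor mass derived from (m/z, charge).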
def extract_top_peaks(mz_arr, int_arr, n_peaks: int = N_PEAKS,
max_mz: float = MAX_MZ) -> np.ndarray:
"""Return (n_peaks, 2) array of top-K peaks sorted by m/z, normalised.
Column 0: m/z / max_mz ∈ [0, 1]. Column 1: intensity / max_intensity ∈ [0, 1].
Rows beyond the actual peak count are zero-padded."""
mz = np.asarray(mz_arr, dtype=np.float32)
ity = np.asarray(int_arr, dtype=np.float32)
valid = (mz > 50) & (mz < max_mz)
mz, ity = mz[valid], ity[valid]
if len(ity) == 0:
return np.zeros((n_peaks, 2), dtype=np.float32)
ity = ity / (ity.max() + 1e-9)
if len(ity) > n_peaks:
top = np.argpartition(ity, -n_peaks)[-n_peaks:]
else:
top = np.arange(len(ity))
mz_s, ity_s = mz[top], ity[top]
order = np.argsort(mz_s)
mz_s, ity_s = mz_s[order], ity_s[order]
out = np.zeros((n_peaks, 2), dtype=np.float32)
n = len(mz_s)
out[:n, 0] = mz_s / max_mz
out[:n, 1] = ity_s
return out
# ── Diffusion Schedule (cosine, D3PM-style) ────────────────────────────────────
# Cosine schedule absorbs tokens more uniformly than linear, avoiding degenerate
# distributions at very small or very large t.
_t_cos = torch.arange(T_STEPS + 1, dtype=torch.float) / T_STEPS
_f = torch.cos((_t_cos + 0.008) / 1.008 * math.pi / 2) ** 2
_alpha_bars = (_f / _f[0]).clamp(min=1e-5)[1:] # (T,) ᾱ_t ∈ (0,1]
_betas = (1 - _alpha_bars[1:] / _alpha_bars[:-1]).clamp(0, 0.999)
_betas = torch.cat([torch.tensor([1 - _alpha_bars[0]]), _betas])
# Accelerated inference anchors: 20 evenly-spaced timesteps plus a final t=0 step (limits error accumulation)
_INFER_STEPS = list(range(T_STEPS - 1, -1, -(T_STEPS // 20)))
if _INFER_STEPS[-1] != 0:
_INFER_STEPS.append(0)
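# Schedule sanity check (illustrative comments only): with T_STEPS = 200 the anchors are
#   _INFER_STEPS = [199, 189, ..., 9, 0]      # 20 strides of 10 plus the appended final 0
# and the cosine ᾱ_t runs from ≈ 0.9997 at t = 0 (almost nothing corrupted) down to the
# 1e-5 floor at t = 199 (essentially every token absorbed), so starting inference from an
# all-MASK sequence matches the t = T training distribution.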
def set_seed(seed: int):
"""Set all RNG seeds for reproducibility across training runs."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def q_sample(x0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Forward process: keep x0 token with prob αΎ±_t, else absorb to MASK_TOK.
Absorbing diffusion aligns with mask-predict decoding: re-corrupting uncertain
positions at inference with MASK_TOK now matches the training distribution."""
abar = _alpha_bars[t.cpu()].to(x0.device).unsqueeze(1)
keep = torch.bernoulli(abar.expand_as(x0.float())).bool()
mask = torch.full_like(x0, MASK_TOK)
return torch.where(keep, x0, mask)
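# Illustrative sketch of the forward process (hypothetical helper; nothing in this
# module calls it). Each position is independently kept with probability ᾱ_t or
# absorbed to MASK_TOK; there is no partial corruption, which is what allows
# mask_predict_decode() further below to re-mask positions at inference time.
def _q_sample_demo() -> torch.Tensor:
    toy_x0 = torch.tensor([[1, 3, 8, 18, 2] + [0] * (SEQ_LEN - 5)])  # SOS A G S EOS PAD...
    late_t = torch.tensor([150])     # late timestep: ᾱ_t is small, so most tokens get absorbed
    return q_sample(toy_x0, late_t)  # e.g. [[1, 23, 8, 23, 23, ...]] (stochastic)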
# ── Peak-Level Augmentation ────────────────────────────────────────────────────
def augment_peaks(peaks: torch.Tensor, p: float = AUG_PROB,
noise_frac: float = AUG_NOISE_STD) -> torch.Tensor:
"""
Peak-level augmentation for the PeakEncoder (replaces old augment_spectrum).
    With probability p per spectrum:
      - Add Gaussian noise to the intensities (std = noise_frac on the max-normalised scale).
      - Jitter the m/z values slightly (simulates instrument calibration variance).
      - Randomly zero out 10% of peaks (simulates missing ions).
    peaks: (B, K, 2) with column 0 = m/z/MAX_MZ, column 1 = intensity.
    Augmentation is skipped at inference (torch.is_grad_enabled() == False).
"""
if not torch.is_grad_enabled():
return peaks
B, K, _ = peaks.shape
aug_mask = torch.rand(B, device=peaks.device) < p # (B,) spectra to augment
if not aug_mask.any():
return peaks
out = peaks.clone()
# Intensity noise
int_noise = torch.randn(B, K, device=peaks.device) * noise_frac
out[:, :, 1] = (out[:, :, 1] + int_noise).clamp(0.0, 1.0)
    # m/z jitter: std 0.005 normalised (≈ ±10 Da at MAX_MZ = 2000), simulating calibration variance
mz_jitter = torch.randn(B, K, device=peaks.device) * 0.005
present = out[:, :, 0] > 0 # don't jitter padded peaks
out[:, :, 0] = (out[:, :, 0] + mz_jitter * present).clamp(0.0, 1.0)
# Random peak dropout (10%)
drop = torch.rand(B, K, device=peaks.device) < 0.10
out[:, :, 0] = out[:, :, 0].masked_fill(drop, 0.0)
out[:, :, 1] = out[:, :, 1].masked_fill(drop, 0.0)
aug_mask_2d = aug_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1)
return torch.where(aug_mask_2d, out, peaks)
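# Illustrative sketch (hypothetical helper; nothing in this module calls it): applies
# augment_peaks to a random batch and reports which spectra were actually modified.
# Gradients must be enabled (the default outside torch.no_grad()), because
# augment_peaks uses torch.is_grad_enabled() as its train/inference switch.
def _augment_peaks_demo(batch_size: int = 4) -> torch.Tensor:
    fake = torch.rand(batch_size, N_PEAKS, 2)                # fake (B, K, 2) batch in [0, 1]
    aug = augment_peaks(fake)
    # Roughly AUG_PROB of the spectra differ from the input; the rest pass through untouched.
    return (aug != fake).reshape(batch_size, -1).any(dim=1)  # (B,) bool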
# ── NOVEL #1: Entropy-Adaptive Mass Gate ──────────────────────────────────────
def entropy_adaptive_gate(logits: torch.Tensor,
precursor_masses: torch.Tensor,
base_tol: float = GATE_BASE_TOL,
alpha: float = GATE_ALPHA) -> tuple:
"""
    Per-position tolerance = base_tol * (1 + alpha * H_t(i) / log(V)).
    High entropy (uncertain) → relax the gate. Low entropy (confident) → keep it tight.
    logits: (B, L, V); the input is left untouched and a gated clone is returned
    precursor_masses: (B,) neutral peptide mass in Da
    Returns: (gated_logits, gate_confidence)
      gate_confidence: (B, L), 1 = tight gate held, 0 = gate had to be relaxed
"""
B, L, V = logits.shape
mass_lut = RESIDUE_MASS.to(logits.device) # (V,)
out = logits.clone()
# Skip gate entirely for spectra with unknown precursor mass
valid = precursor_masses > 1.0 # (B,)
# Per-position entropy β†’ per-position tolerance
with torch.no_grad():
probs = F.softmax(logits, dim=-1) # (B, L, V)
entropy = -(probs * (probs + 1e-9).log()).sum(-1) # (B, L)
tol = base_tol * (1 + alpha * entropy / math.log(V)) # (B, L)
gate_conf = torch.ones(B, L, device=logits.device) # 1 = tight
# Best-guess token at each position for the "other positions" mass estimate
best = logits.argmax(-1) # (B, L)
best_m = mass_lut[best] # (B, L)
is_aa = (best >= 3).float() # (B, L)
for pos in range(L):
# Mass from all OTHER real-AA positions
other = (best_m * is_aa).sum(1) \
- best_m[:, pos] * is_aa[:, pos] \
+ WATER_MASS # (B,)
# Candidate total mass for each token at this position
cand = other.unsqueeze(1) + mass_lut.unsqueeze(0) # (B, V)
pos_tol = tol[:, pos].unsqueeze(1) # (B, 1)
feasible = (cand - precursor_masses.unsqueeze(1)).abs() < pos_tol
feasible[:, :3] = True # PAD/SOS/EOS always pass
feasible[~valid] = True # skip gate for unknown-mass spectra
needs_relax = ~feasible.any(dim=-1) # (B,)
gate_conf[:, pos] = (~needs_relax).float()
# If even relaxed gate (0.1 Da) zeros everything, leave logits unchanged
if needs_relax.any():
relax_feasible = (cand - precursor_masses.unsqueeze(1)).abs() < 0.1
relax_feasible[:, :3] = True
relax_feasible[~valid] = True
final_mask = torch.where(
needs_relax.unsqueeze(1), relax_feasible, feasible
)
else:
final_mask = feasible
out[:, pos].masked_fill_(~final_mask, float('-inf'))
return out, gate_conf
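# Worked example of the tolerance formula (illustrative comments only): with
# GATE_BASE_TOL = 0.02 and GATE_ALPHA = 2.0, a fully confident position (entropy 0)
# keeps the tight 0.02 Da tolerance, while a maximally uncertain position (uniform
# over the V tokens, entropy log V) is relaxed to 0.02 * (1 + 2.0) = 0.06 Da, with
# the flat 0.1 Da fallback used only when even that leaves no feasible residue.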
# ── Model Components ───────────────────────────────────────────────────────────
class SinusoidalEmbedding(nn.Module):
def __init__(self, dim: int):
super().__init__()
half = dim // 2
freq = torch.exp(-math.log(10000) * torch.arange(half) / half)
self.register_buffer('freq', freq)
self.proj = nn.Linear(dim, dim)
def forward(self, t: torch.Tensor) -> torch.Tensor:
t = t.float().unsqueeze(1)
x = t * self.freq.unsqueeze(0)
x = torch.cat([x.sin(), x.cos()], dim=-1)
return self.proj(x)
class PeakEncoder(nn.Module):
"""
Transformer encoder over the top-K (m/z, intensity) peak pairs.
Replaces the binned-spectrum CNN encoder with direct peak-level attention,
    which is how InstaNovo and other recent de-novo models encode spectra.
Prepends a precursor-mass token so the model always knows the target mass.
"""
def __init__(self, d_model: int = 512, n_peaks: int = N_PEAKS,
nhead: int = 8, n_layers: int = 6):
super().__init__()
self.mz_emb = SinusoidalEmbedding(d_model) # embed m/z as float
self.int_proj = nn.Linear(1, d_model)
self.mass_emb = SinusoidalEmbedding(d_model) # embed precursor mass
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
batch_first=True, norm_first=True, dropout=0.1)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
def forward(self, peaks: torch.Tensor,
masses: torch.Tensor) -> tuple:
"""
        peaks: (B, K, 2) with (mz/MAX_MZ, intensity)
        masses: (B,) neutral precursor mass in Da
Returns ((B, K+1, d_model), (B, K+1) bool pad mask)
The pad mask is True for positions that should be ignored (padded peaks).
"""
B, K, _ = peaks.shape
mz_scaled = peaks[:, :, 0].reshape(-1) * MAX_MZ # (B*K,)
mz_tok = self.mz_emb(mz_scaled).reshape(B, K, -1) # (B, K, d)
int_tok = self.int_proj(peaks[:, :, 1:]) # (B, K, d)
peak_tok = mz_tok + int_tok # (B, K, d)
mass_tok = self.mass_emb(masses).unsqueeze(1) # (B, 1, d)
x = torch.cat([mass_tok, peak_tok], dim=1) # (B, K+1, d)
# Mask zero-padded peaks (m/z == 0 means no peak)
pad = torch.zeros(B, K + 1, dtype=torch.bool, device=peaks.device)
pad[:, 1:] = (peaks[:, :, 0] == 0)
out = self.encoder(x, src_key_padding_mask=pad) # (B, K+1, d)
return out, pad
def compute_by_pair_bias(xt: torch.Tensor, peaks: torch.Tensor,
precursor_masses: torch.Tensor,
sigma: float = BY_SIGMA) -> torch.Tensor:
"""
AlphaFold3-inspired B/Y ion pair bias.
For each (sequence position i, peak j) pair, compute the Gaussian proximity
of the theoretical b/y ion at position i to peak j's m/z, weighted by
peak intensity. Committed (non-MASK) positions contribute signal;
    masked positions contribute zero, so the model learns to use committed
positions to guide prediction of uncertain ones.
    Returns (B, L, K+1); the first column is the mass token (always 0 bias).
"""
B, L = xt.shape
K = peaks.shape[1]
device = xt.device
mass_lut = RESIDUE_MASS.to(device) # (V,)
    # Token masses; MASK and special tokens → 0
is_aa = ((xt >= 3) & (xt < MASK_TOK)).float() # (B, L)
tok_mass = mass_lut[xt.clamp(0, VOCAB_SIZE - 1)] * is_aa # (B, L)
    # Cumulative residue mass from position 0 → b-ion series
cum_mass = torch.cumsum(tok_mass, dim=1) # (B, L)
b_ions = cum_mass + PROTON_MASS # (B, L)
    # y-ion series for the complementary suffix: y = (M - H2O - prefix) + H2O + proton
    pm = precursor_masses.unsqueeze(1)                    # (B, 1)
    valid = (pm > 1.0).float()                            # skip unknown mass
    y_ions = (pm - cum_mass + PROTON_MASS) * valid        # (B, L)
# Observed peak m/z, unnormalised
mz_obs = peaks[:, :, 0] * MAX_MZ # (B, K)
int_obs = peaks[:, :, 1] # (B, K)
# Pairwise Gaussian proximity: (B, L, K)
b_diff = b_ions.unsqueeze(2) - mz_obs.unsqueeze(1) # (B, L, K)
y_diff = y_ions.unsqueeze(2) - mz_obs.unsqueeze(1) # (B, L, K)
prox = (torch.exp(-b_diff**2 / (2 * sigma**2)) +
torch.exp(-y_diff**2 / (2 * sigma**2))) # (B, L, K)
prox = prox * int_obs.unsqueeze(1) # weight by intensity
prox = prox * is_aa.unsqueeze(2) # zero out masked pos
# Prepend zero column for the mass token: (B, L, K+1)
return F.pad(prox, (1, 0))
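# Worked b/y example (illustrative comments only) for a hypothetical dipeptide "AG"
# with neutral mass M = 71.03711 + 57.02146 + 18.010565 = 146.069135 Da:
#   b1 = 71.03711 + 1.007276     = 72.044386 Da   (prefix "A" plus a proton)
#   y1 = M - 71.03711 + 1.007276 = 76.039301 Da   (suffix "G" plus H2O plus a proton)
# These singly-charged ladders are what the Gaussian proximity above is scored against,
# with width BY_SIGMA = 0.1 Da and each match weighted by the observed peak intensity.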
class BYBiasedDecoderLayer(nn.Module):
"""
Pre-norm Transformer decoder layer with B/Y ion pair bias injected
into the cross-attention logits (AlphaFold3-style pair bias).
A learnable scalar `bias_scale` (initialised to 0) lets the model
start as a standard decoder and gradually lean on the ion signal.
"""
def __init__(self, d_model: int, nhead: int, dim_ff: int):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.ff = nn.Sequential(
nn.Linear(d_model, dim_ff), nn.GELU(), nn.Linear(dim_ff, d_model))
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.nhead = nhead
# Starts at 0 so model is identical to unbiased decoder at init
self.bias_scale = nn.Parameter(torch.zeros(1))
def forward(self, tgt: torch.Tensor, memory: torch.Tensor,
by_bias: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None) -> torch.Tensor:
B, L, _ = tgt.shape
# Self-attention (pre-norm)
h = self.norm1(tgt)
h, _ = self.self_attn(h, h, h)
tgt = tgt + h
# Cross-attention with optional B/Y pair bias and padding mask
h = self.norm2(tgt)
if by_bias is not None:
            # by_bias: (B, L, K+1) → (B*nhead, L, K+1) float additive mask
bias = (by_bias.unsqueeze(1)
.expand(-1, self.nhead, -1, -1)
.reshape(B * self.nhead, L, -1))
bias = self.bias_scale * bias
# Fuse padding mask into attn_mask as -inf (avoids bool/float mismatch)
if memory_key_padding_mask is not None:
pad_f = memory_key_padding_mask.float().masked_fill(
memory_key_padding_mask, float('-inf')) # (B, K+1)
pad_f = (pad_f.unsqueeze(1).unsqueeze(1)
.expand(-1, self.nhead, L, -1)
.reshape(B * self.nhead, L, -1)) # (B*h, L, K+1)
bias = bias + pad_f
h, _ = self.cross_attn(h, memory, memory, attn_mask=bias)
else:
h, _ = self.cross_attn(h, memory, memory,
key_padding_mask=memory_key_padding_mask)
tgt = tgt + h
# Feed-forward (pre-norm)
h = self.norm3(tgt)
tgt = tgt + self.ff(h)
return tgt
class TransformerDenoiser(nn.Module):
"""
Bidirectional denoiser with peak-level cross-attention.
memory: (B, K+1, d) sequence of peak embeddings from PeakEncoder.
B/Y ion pair bias (AF3-inspired) injected into every cross-attention layer.
Self-conditioning: optionally accepts x0-hat from the previous denoising step
as an additional per-position embedding (MDLM-style, Sahoo 2024).
"""
def __init__(self, vocab_size: int = VOCAB_SIZE, d_model: int = 512,
nhead: int = 8, dim_ff: int = 2048,
num_layers: int = 6, seq_len: int = SEQ_LEN):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
self.pos_emb = nn.Embedding(seq_len, d_model)
self.time_emb = SinusoidalEmbedding(d_model)
        # Self-conditioning embedding: init to zero so it has no effect at epoch 0
self.self_cond_emb = nn.Embedding(vocab_size, d_model)
nn.init.zeros_(self.self_cond_emb.weight)
self.layers = nn.ModuleList([
BYBiasedDecoderLayer(d_model, nhead, dim_ff)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
self.out = nn.Linear(d_model, vocab_size)
def forward(self, xt: torch.Tensor, t: torch.Tensor,
memory: torch.Tensor,
peaks: torch.Tensor | None = None,
precursor_masses: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None,
self_cond: torch.Tensor | None = None) -> torch.Tensor:
B, L = xt.shape
pos = torch.arange(L, device=xt.device).unsqueeze(0).expand(B, -1)
x = self.token_emb(xt) + self.pos_emb(pos)
x = x + self.time_emb(t).unsqueeze(1)
# Self-conditioning: add x0-hat embedding from previous step
if self_cond is not None:
x = x + self.self_cond_emb(self_cond)
by_bias = (compute_by_pair_bias(xt, peaks, precursor_masses)
if peaks is not None and precursor_masses is not None
else None)
for layer in self.layers:
x = layer(x, memory, by_bias=by_bias,
memory_key_padding_mask=memory_key_padding_mask)
return self.out(self.norm(x)) # (B, L, V)
# ── Dataset ────────────────────────────────────────────────────────────────────
class DiffusionDataset(Dataset):
"""Stores (peaks, sequence, precursor_mass) triples for peak-level training."""
def __init__(self, peaks, y, precursor_masses):
self.peaks = torch.tensor(peaks, dtype=torch.float32) # (N, K, 2)
self.y = torch.tensor(y, dtype=torch.long)
self.masses = torch.tensor(precursor_masses, dtype=torch.float32)
def __len__(self):
return len(self.y)
def __getitem__(self, idx):
return self.peaks[idx], self.y[idx], self.masses[idx]
def build_diffusion_dataset(mzml_path: str, xlsx_path: str,
max_spectra: int = 5000,
n_peaks: int = N_PEAKS,
return_raw: bool = False):
"""Returns peaks (N, K, 2), y (N, 32), neutral_masses (N,).
If return_raw=True also returns raw_peaks: list of (mz_arr, int_arr) for SGIR."""
spectra = load_labeled_spectra(mzml_path, xlsx_path, max_spectra)
peaks_list, y, masses, raw_peaks = [], [], [], []
for s in spectra:
peaks_list.append(extract_top_peaks(s['mz'], s['intensity'], n_peaks))
y.append(encode_peptide(s['peptide']))
prec_mz = s.get('precursor_mz') or 0.0
charge = s.get('charge') or 0
neutral = float(charge) * (float(prec_mz) - PROTON_MASS) if charge else 0.0
masses.append(neutral)
if return_raw:
raw_peaks.append((np.asarray(s['mz'], dtype=np.float32),
np.asarray(s['intensity'], dtype=np.float32)))
out = (np.array(peaks_list, dtype=np.float32),
np.array(y),
np.array(masses, dtype=np.float32))
return out + (raw_peaks,) if return_raw else out
# ── Training ───────────────────────────────────────────────────────────────────
def train_diffusion(mzml_paths, xlsx_paths, checkpoint_dir='checkpoints',
epochs=50, batch_size=32, lr=1e-3, device=None, seed=42):
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(seed)
print(f"Device: {device} | Seed: {seed}")
os.makedirs(checkpoint_dir, exist_ok=True)
all_peaks, ys, ms = [], [], []
for mzml, xlsx in zip(mzml_paths, xlsx_paths):
pk, y, m = build_diffusion_dataset(mzml, xlsx)
all_peaks.append(pk); ys.append(y); ms.append(m)
peaks = np.concatenate(all_peaks)
y = np.concatenate(ys)
masses = np.concatenate(ms)
print(f"Total spectra: {len(peaks)}")
N = len(peaks)
rng = np.random.default_rng(42)
idx = rng.permutation(N)
n_tr = int(0.70 * N); n_va = int(0.15 * N)
tr, va = idx[:n_tr], idx[n_tr:n_tr + n_va]
te = idx[n_tr + n_va:]
train_dl = DataLoader(DiffusionDataset(peaks[tr], y[tr], masses[tr]),
batch_size=batch_size, shuffle=True, drop_last=True)
val_dl = DataLoader(DiffusionDataset(peaks[va], y[va], masses[va]),
batch_size=batch_size, drop_last=True)
encoder = PeakEncoder().to(device)
denoiser = TransformerDenoiser().to(device)
params = list(encoder.parameters()) + list(denoiser.parameters())
opt = optim.AdamW(params, lr=lr, weight_decay=1e-2)
    warmup_ep = max(5, epochs // 20)  # at least 5 warmup epochs (~5% of a long run)
def _lr_lambda(ep):
if ep < warmup_ep:
return (ep + 1) / warmup_ep
t = (ep - warmup_ep) / max(epochs - warmup_ep, 1)
return 0.5 * (1.0 + math.cos(math.pi * t)) * (1 - 1e-5 / lr) + 1e-5 / lr
scheduler = optim.lr_scheduler.LambdaLR(opt, _lr_lambda)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=0)
best_val = float('inf')
for epoch in range(1, epochs + 1):
encoder.train(); denoiser.train()
tr_loss = 0.0
for pks, seq, mass in train_dl:
pks, seq, mass = pks.to(device), seq.to(device), mass.to(device)
B = seq.shape[0]
t = torch.randint(0, T_STEPS, (B,), device=device)
xt = q_sample(seq, t)
pks_aug = augment_peaks(pks)
memory, pad_mask = encoder(pks_aug, mass)
# Self-conditioning (MDLM): 50% of steps, run a no-grad forward pass
# first to get x0-hat, then pass it as self_cond for the actual step.
self_cond = None
if torch.rand(1).item() < 0.5:
with torch.no_grad():
sc_logits = denoiser(xt, t, memory, peaks=pks_aug,
precursor_masses=mass,
memory_key_padding_mask=pad_mask)
self_cond = sc_logits.argmax(-1).detach()
logits = denoiser(xt, t, memory, peaks=pks_aug, precursor_masses=mass,
memory_key_padding_mask=pad_mask, self_cond=self_cond)
ce_loss = criterion(logits.reshape(-1, VOCAB_SIZE), seq.reshape(-1))
# Mass consistency loss (relative L1 on expected sequence mass)
valid_m = mass > 1.0
if valid_m.any():
mass_lut = RESIDUE_MASS.to(device)
probs = F.softmax(logits, dim=-1)
exp_mass = (probs * mass_lut).sum(-1)
is_aa_soft = 1.0 - probs[:, :, :3].sum(-1)
pred_mass = (exp_mass * is_aa_soft).sum(-1) + WATER_MASS
mass_loss = ((pred_mass[valid_m] - mass[valid_m]).abs()
/ mass[valid_m]).mean()
loss = ce_loss + MASS_LOSS_WEIGHT * mass_loss
else:
loss = ce_loss
opt.zero_grad(); loss.backward()
nn.utils.clip_grad_norm_(params, 1.0)
opt.step()
tr_loss += ce_loss.item() * B
encoder.eval(); denoiser.eval()
va_loss = 0.0
with torch.no_grad():
for pks, seq, mass in val_dl:
pks, seq, mass = pks.to(device), seq.to(device), mass.to(device)
t = torch.randint(0, T_STEPS, (seq.shape[0],), device=device)
xt = q_sample(seq, t)
memory, pad_mask = encoder(pks, mass)
logits = denoiser(xt, t, memory, peaks=pks, precursor_masses=mass,
memory_key_padding_mask=pad_mask)
va_loss += criterion(logits.reshape(-1, VOCAB_SIZE),
seq.reshape(-1)).item() * pks.shape[0]
tr_avg = tr_loss / len(train_dl.dataset)
va_avg = va_loss / len(val_dl.dataset)
print(f"Epoch {epoch:3d} | train {tr_avg:.4f} | val {va_avg:.4f}")
if epoch % 10 == 0:
path = os.path.join(checkpoint_dir, f'diffusion_ckpt_{epoch}.pt')
torch.save({'epoch': epoch, 'encoder': encoder.state_dict(),
'denoiser': denoiser.state_dict()}, path)
print(f" Saved {path}")
scheduler.step()
if va_avg < best_val:
best_val = va_avg
torch.save({'epoch': epoch, 'encoder': encoder.state_dict(),
'denoiser': denoiser.state_dict()},
os.path.join(checkpoint_dir, 'diffusion_best.pt'))
    # Always save the last epoch: AA recall peaks here, not at the best val CE loss
torch.save({'epoch': epochs, 'encoder': encoder.state_dict(),
'denoiser': denoiser.state_dict()},
os.path.join(checkpoint_dir, 'diffusion_final.pt'))
print(f" Saved diffusion_final.pt (epoch {epochs})")
return encoder, denoiser, (peaks[te], y[te], masses[te])
# ── Inference ──────────────────────────────────────────────────────────────────
def _seq_mass(seq: str) -> float:
"""Monoisotopic neutral mass of a peptide string (residues + H2O)."""
return sum(_MONO_MASSES[VOCAB.index(c)] for c in seq if c in VOCAB) + WATER_MASS
def mass_correct_sequence(seq: str, precursor_mass: float,
tol: float = 0.05) -> str:
"""
Option B: post-hoc mass correction.
If the predicted sequence mass is outside tol Da of precursor_mass,
try all single-position amino acid swaps and keep the one that minimises
    the mass error. Returns seq unchanged when precursor_mass is unknown (≤ 1 Da)
or already within tolerance.
"""
if precursor_mass <= 1.0 or not seq:
return seq
current_delta = abs(_seq_mass(seq) - precursor_mass)
if current_delta <= tol:
return seq
best_seq, best_delta = seq, current_delta
for pos in range(len(seq)):
for aa in VOCAB:
if aa == seq[pos]:
continue
candidate = seq[:pos] + aa + seq[pos + 1:]
delta = abs(_seq_mass(candidate) - precursor_mass)
if delta < best_delta:
best_delta, best_seq = delta, candidate
    # Only apply if the correction actually achieves tolerance: a swap that
# merely reduces the error without reaching tolerance swaps a correct AA
# for a wrong one more often than not.
return best_seq if best_delta <= tol else seq
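# Worked example (illustrative comments only, hypothetical numbers): if the decoder
# emits "GAS" (233.10117 Da, see the residue table above) but the precursor mass is
# 247.11682 Da, the single swap S -> T adds 101.04768 - 87.03203 = 14.01565 Da and
# lands on 247.11682 Da, so "GAT" is returned; if no single swap reaches the
# tolerance, the original sequence is kept unchanged.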
# ── NOVEL #8: Mass-Constrained Beam Search ────────────────────────────────────
_MIN_RESIDUE_MASS = min(_MONO_MASSES) # G = 57.02 Da
_MAX_RESIDUE_MASS = max(_MONO_MASSES) # W = 186.08 Da
def mass_constrained_beam_search(logits_t0: torch.Tensor,
precursor_masses: torch.Tensor,
beam_width: int = 20,
tol: float = 0.1) -> list:
"""
NOVEL #8: replace independent argmax with left-to-right beam search.
At each position, expand beams using top-k token log-probs and prune
branches whose remaining mass budget is infeasible given precursor_mass.
logits_t0: (B, L, V)
precursor_masses: (B,)
    Returns: list[B] of token-id lists (best beam per spectrum; decode with decode_tokens).
"""
B, L, V = logits_t0.shape
log_probs = F.log_softmax(logits_t0, dim=-1) # (B, L, V)
mass_lut = RESIDUE_MASS.to(logits_t0.device) # (V,)
results = []
for b in range(B):
pm = float(precursor_masses[b].cpu())
lp = log_probs[b] # (L, V)
# beam: list of (score, [token_ids], accumulated_aa_mass)
beams = [(0.0, [], 0.0)]
for pos in range(L):
top_scores, top_toks = lp[pos].topk(beam_width)
new_beams = []
for score, toks, acc_mass in beams:
for s, tok in zip(top_scores.tolist(), top_toks.tolist()):
new_score = score + s
tok_mass = float(mass_lut[tok].cpu())
new_mass = acc_mass + tok_mass if tok >= 3 else acc_mass
# Remaining positions: [pos+1 .. L-1]
remaining = L - pos - 1
# Feasibility: can any filling of remaining positions hit pm?
lo = new_mass + WATER_MASS + remaining * _MIN_RESIDUE_MASS
hi = new_mass + WATER_MASS + remaining * _MAX_RESIDUE_MASS
                    # If tok is EOS we are done: check the accumulated mass now
if tok == 2:
delta = abs(new_mass + WATER_MASS - pm)
if pm <= 1.0 or delta <= tol * 3:
new_beams.append((new_score, toks + [tok], new_mass))
continue
# Prune if no remaining filling can reach pm (unless mass unknown)
if pm > 1.0 and (lo > pm + tol or hi < pm - tol):
continue
new_beams.append((new_score, toks + [tok], new_mass))
if not new_beams:
# Fallback: keep existing beams without pruning
new_beams = [(sc + float(lp[pos].max().cpu()),
tk + [int(lp[pos].argmax().cpu())], am)
for sc, tk, am in beams]
# Keep top beam_width by score
new_beams.sort(key=lambda x: x[0], reverse=True)
beams = new_beams[:beam_width]
best_toks = beams[0][1]
results.append(best_toks)
return results
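# Pruning intuition (illustrative comments only): a beam is dropped when even filling
# every remaining position with tryptophan (186.08 Da) still falls short of the
# precursor mass, or when filling every remaining position with glycine (57.02 Da)
# already overshoots it; only beams whose residual mass budget can still be met by
# some combination of the 20 standard residues survive to the next position.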
# ── NOVEL #11: Mask-Predict Iterative Decoding ────────────────────────────────
@torch.no_grad()
def mask_predict_decode(logits_t0: torch.Tensor,
context: torch.Tensor,
denoiser: nn.Module,
n_iter: int = 8,
peaks: torch.Tensor | None = None,
precursor_masses: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None,
self_cond: torch.Tensor | None = None) -> list:
"""
NOVEL #11: Mask-Predict iterative decoding for the bidirectional denoiser.
Correct counterpart to left-to-right beam search: keeps highest-confidence
positions fixed and re-corrupts the uncertain ones, letting the bidirectional
model refine them conditioned on the already-committed positions.
Grounded in Ghazvininejad et al. 2019 (Mask-Predict) and LLaDA 2025.
    No training changes needed; works on existing checkpoints.
"""
B, L, V = logits_t0.shape
device = logits_t0.device
probs = F.softmax(logits_t0, dim=-1)
conf = probs.max(dim=-1).values # (B, L)
x0 = logits_t0.argmax(dim=-1) # (B, L)
t_zero = torch.zeros(B, dtype=torch.long, device=device)
# Linear schedule: commit progressively more positions each round
n_uncertain_schedule = [
int(L * (1.0 - (i + 1) / n_iter)) for i in range(n_iter)
    ]  # e.g. n_iter=4, L=32 → [24, 16, 8, 0]
for step in range(n_iter):
n_uncertain = n_uncertain_schedule[step]
if n_uncertain == 0:
break
# Mask least-confident positions
rank = conf.argsort(dim=-1, descending=False) # lowest conf first
uncertain_mask = torch.zeros(B, L, dtype=torch.bool, device=device)
uncertain_mask.scatter_(1, rank[:, :n_uncertain], True)
        # Corrupt uncertain positions with MASK_TOK (matches absorbing q_sample)
mask_fill = torch.full((B, L), MASK_TOK, dtype=torch.long, device=device)
xt_new = torch.where(uncertain_mask, mask_fill, x0)
        # Re-denoise at t=0; the bidirectional model conditions on the fixed positions
logits_new = denoiser(xt_new, t_zero, context,
peaks=peaks, precursor_masses=precursor_masses,
memory_key_padding_mask=memory_key_padding_mask,
self_cond=self_cond)
logits_new[..., MASK_TOK] = float('-inf') # never output MASK
probs_new = F.softmax(logits_new, dim=-1)
# Update only uncertain positions; carry x0 forward as self_cond
x0 = torch.where(uncertain_mask, logits_new.argmax(dim=-1), x0)
conf = torch.where(uncertain_mask, probs_new.max(dim=-1).values, conf)
self_cond = x0
return [decode_tokens(x0[b].cpu().tolist()) for b in range(B)]
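# Schedule example (illustrative comments only): with the default n_iter = 8 and
# L = SEQ_LEN = 32, the number of re-masked (least-confident) positions per round is
#   [28, 24, 20, 16, 12, 8, 4, 0]
# so the first round re-predicts 28 positions conditioned on the 4 most confident
# ones, and four additional positions are committed on every subsequent round.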
# ── NOVEL #9: Spectrally-Grounded Iterative Refinement ────────────────────────
def _compute_by_ions(seq: str) -> tuple:
"""Compute theoretical b and y ion m/z values (singly charged, +1 proton)."""
masses = [_MONO_MASSES[VOCAB.index(c)] for c in seq if c in VOCAB]
if not masses:
return np.array([]), np.array([])
b_ions = np.cumsum(masses) + PROTON_MASS
y_ions = np.cumsum(masses[::-1]) + WATER_MASS + PROTON_MASS
return b_ions, y_ions
def _peak_support(theoretical_mz: np.ndarray, obs_mz: np.ndarray,
obs_int: np.ndarray, tol: float = 0.02) -> float:
"""Intensity-weighted fraction of theoretical ions matched in observed spectrum."""
if len(theoretical_mz) == 0 or len(obs_mz) == 0:
return 0.0
total = 0.0
for tmz in theoretical_mz:
delta = np.abs(obs_mz - tmz)
idx = delta.argmin()
if delta[idx] <= tol:
total += float(obs_int[idx])
denom = obs_int.sum()
return total / denom if denom > 0 else 0.0
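# Worked example (illustrative comments only, hypothetical spectrum): with observed
# peaks at m/z [72.04, 76.04, 120.00] and normalised intensities [0.5, 0.3, 0.2],
# a candidate whose theoretical ions are [72.044, 175.119] matches only the first
# peak within 0.02 Da, giving support = 0.5 / (0.5 + 0.3 + 0.2) = 0.5, i.e. half of
# the total observed intensity is explained by that candidate's ion ladder.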
def spectrally_grounded_refine(seq: str, obs_mz: np.ndarray, obs_int: np.ndarray,
precursor_mass: float, t0_log_probs: torch.Tensor,
max_iter: int = 5, ion_tol: float = 0.02) -> str:
"""
NOVEL #9: Spectrally-Grounded Iterative Refinement (SGIR).
For each iteration:
1. Compute b/y ion support per position.
2. Find weakest-supported position.
3. Try all mass-feasible AA substitutions there; pick the one maximising
support + t=0 log-prob. Stop when no improvement.
t0_log_probs: (L, V) log-softmax from denoiser at t=0 for this spectrum.
"""
if not seq or precursor_mass <= 1.0 or len(obs_mz) == 0:
return seq
# Normalise observed intensities
obs_int = obs_int / (obs_int.max() + 1e-9)
for _ in range(max_iter):
b_ions, y_ions = _compute_by_ions(seq)
if len(b_ions) == 0:
break
        # Per-position support: combined support of b[i] and the complementary y ion containing residue i
n = len(seq)
support = np.zeros(n)
for i in range(n):
bi = b_ions[i:i+1] if i < len(b_ions) else np.array([])
yi = y_ions[n-1-i:n-i] if (n-1-i) < len(y_ions) else np.array([])
ions = np.concatenate([bi, yi]) if len(bi) or len(yi) else np.array([])
support[i] = _peak_support(ions, obs_mz, obs_int, ion_tol)
worst_pos = int(np.argmin(support))
best_seq, best_score = seq, support[worst_pos]
for aa in VOCAB:
if aa == seq[worst_pos]:
continue
candidate = seq[:worst_pos] + aa + seq[worst_pos + 1:]
# Mass feasibility
if abs(_seq_mass(candidate) - precursor_mass) > 0.1:
continue
# New support at worst_pos
b_c, y_c = _compute_by_ions(candidate)
if len(b_c) == 0:
continue
bi = b_c[worst_pos:worst_pos+1] if worst_pos < len(b_c) else np.array([])
yi = y_c[n-1-worst_pos:n-worst_pos] if (n-1-worst_pos) < len(y_c) else np.array([])
ions = np.concatenate([bi, yi]) if len(bi) or len(yi) else np.array([])
new_support = _peak_support(ions, obs_mz, obs_int, ion_tol)
# Tiebreak with t=0 log-prob
tok = CHAR_TO_IDX.get(aa, 0)
lp_bonus = float(t0_log_probs[worst_pos, tok].cpu()) if tok < t0_log_probs.shape[1] else 0.0
score = new_support + 0.05 * lp_bonus
if score > best_score:
best_score, best_seq = score, candidate
if best_seq == seq:
break
seq = best_seq
return seq
# ── NOVEL #10: ESM-2 + Spectral Posterior Reranking ──────────────────────────
_esm2_model = None
_esm2_tokenizer = None
def _get_esm2():
global _esm2_model, _esm2_tokenizer
if _esm2_model is None:
try:
            # The masked-LM head is needed for residue log-probs; bare EsmModel exposes no vocab logits
            from transformers import EsmForMaskedLM, EsmTokenizer
            _esm2_tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
            _esm2_model = EsmForMaskedLM.from_pretrained('facebook/esm2_t6_8M_UR50D')
_esm2_model.eval()
except Exception:
_esm2_model = None
return _esm2_model, _esm2_tokenizer
def esm2_pseudo_perplexity(seqs: list) -> list:
"""Batch ESM-2 pseudo-perplexity using one-fell-swoop masking."""
model, tok = _get_esm2()
if model is None or not seqs:
return [0.0] * len(seqs)
ppls = []
for seq in seqs:
if not seq:
ppls.append(999.0)
continue
        try:
            inputs = tok(seq, return_tensors='pt')
ids = inputs['input_ids'][0] # (L+2,)
L = len(ids) - 2
log_prob_sum = 0.0
with torch.no_grad():
for i in range(1, L + 1):
masked = ids.clone()
masked[i] = tok.mask_token_id
out = model(input_ids=masked.unsqueeze(0),
attention_mask=inputs['attention_mask'])
                    logits = out.logits[0, i]  # masked-LM vocabulary logits at position i
lp = float(torch.log_softmax(logits, dim=-1)[ids[i]])
log_prob_sum += lp
ppls.append(float(np.exp(-log_prob_sum / L)))
except Exception:
ppls.append(999.0)
return ppls
def rank_candidates(candidates: list, spectral_lps: list, obs_mz: np.ndarray,
obs_int: np.ndarray, precursor_mass: float,
lam: float = 0.05, use_esm: bool = True) -> str:
"""
NOVEL #10: score each candidate by:
score(c) = spectral_logprob(c) - lam * esm2_ppl(c) + spectral_support(c)
Returns the best candidate string.
"""
if not candidates:
return ''
if len(candidates) == 1:
return candidates[0]
obs_int_norm = obs_int / (obs_int.max() + 1e-9) if len(obs_int) > 0 else obs_int
ppls = esm2_pseudo_perplexity(candidates) if use_esm else [0.0]*len(candidates)
best_seq, best_score = candidates[0], float('-inf')
for seq, sp_lp, ppl in zip(candidates, spectral_lps, ppls):
b_ions, y_ions = _compute_by_ions(seq)
all_ions = np.concatenate([b_ions, y_ions]) if len(b_ions) > 0 else np.array([])
supp = _peak_support(all_ions, obs_mz, obs_int_norm) if len(all_ions) > 0 else 0.0
score = sp_lp - lam * ppl + supp
if score > best_score:
best_score, best_seq = score, seq
return best_seq
def decode_tokens(tokens) -> str:
result = []
for tok in tokens:
if int(tok) == 2:
break
if int(tok) >= 3:
result.append(IDX_TO_CHAR.get(int(tok), '?'))
return ''.join(result)
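# Example (illustrative comments only): decode_tokens([1, 3, 8, 18, 2, 0, 0]) -> "AGS";
# SOS (1) is skipped, EOS (2) terminates, and trailing PAD is never reached, with
# tokens 3, 8 and 18 mapping to A, G and S under the residue table at the top of this file.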
@torch.no_grad()
def generate_sequences(encoder, denoiser, spectra, precursor_masses,
n_candidates: int = 5, T_sample: float = 1.0,
t_infer: int = 100, device=None, use_gate: bool = False,
use_beam: bool = False, use_cfid: bool = False):
"""
    Iterative reverse diffusion over _INFER_STEPS (20 evenly-spaced anchors plus a final t=0 step).
    use_gate=True: final-step logits pass through the entropy-adaptive mass gate (NOVEL #1).
    use_beam=True: final step uses mass-constrained beam search (NOVEL #8).
    use_cfid=True: final step uses Mask-Predict iterative decoding (NOVEL #11).
    T_sample < 1.0: final step samples with temperature instead of argmax
        (used for multi-candidate generation for reranking).
    Returns:
        sequences: list[N] of list[n_candidates] strings
        spectral_lps: list[N] of list[n_candidates] floats
        gate_confs: list[N] of list[n_candidates] floats (mean per-position gate
            confidence when use_gate=True, else 1.0)
        t0_logits: list[N] of (L, V) log-prob tensors at t=0 (last candidate, for SGIR)
"""
if device is None:
device = next(encoder.parameters()).device
encoder.eval(); denoiser.eval()
spec_t = torch.tensor(spectra, dtype=torch.float32, device=device)
mass_t = torch.tensor(precursor_masses, dtype=torch.float32, device=device)
context, peak_pad_mask = encoder(spec_t, mass_t) # (N, K+1, d), (N, K+1)
N = len(spectra)
sequences = [[] for _ in range(N)]
spectral_lps = [[] for _ in range(N)]
gate_confs = [[] for _ in range(N)]
t0_logits = [None] * N
n_steps = len(_INFER_STEPS)
for cand_idx in range(n_candidates):
# Start from fully masked sequence (absorbing diffusion: t=T ≑ all-MASK)
xt = torch.full((N, SEQ_LEN), MASK_TOK, dtype=torch.long, device=device)
self_cond = None # self-conditioning: x0-hat from previous step
for step_idx, t_cur in enumerate(_INFER_STEPS):
t_vec = torch.full((N,), t_cur, dtype=torch.long, device=device)
logits = denoiser(xt, t_vec, context,
peaks=spec_t, precursor_masses=mass_t,
memory_key_padding_mask=peak_pad_mask,
self_cond=self_cond) # (N, L, V)
logits[..., MASK_TOK] = float('-inf') # never predict MASK as output
is_final = (step_idx == n_steps - 1)
if not is_final:
                x0_hat = logits.argmax(-1)           # (N, L); EOS/PAD allowed
self_cond = x0_hat # carry forward for next step
t_next = _INFER_STEPS[step_idx + 1]
t_next_vec = torch.full((N,), t_next, dtype=torch.long, device=device)
xt = q_sample(x0_hat, t_next_vec)
            else:
                gate_conf_b = None
                if use_gate:
                    # NOVEL #1: apply the entropy-adaptive mass gate to the final-step logits
                    logits, gate_conf_b = entropy_adaptive_gate(logits, mass_t)
                log_fin = F.log_softmax(logits, dim=-1)      # (N, L, V)
# Store t=0 logits from last candidate for SGIR
for i in range(N):
t0_logits[i] = log_fin[i].cpu()
if use_beam:
# NOVEL #8: mass-constrained beam search
beam_toks_list = mass_constrained_beam_search(
logits, mass_t, beam_width=20)
for bi, toks in enumerate(beam_toks_list):
toks_pad = (toks + [0] * SEQ_LEN)[:SEQ_LEN]
tok_t = torch.tensor(toks_pad, dtype=torch.long, device=device)
lp = log_fin[bi].gather(-1, tok_t.unsqueeze(-1)).squeeze(-1)
aa_mask = (tok_t >= 3).float()
sp_lp_i = (lp * aa_mask).sum() / aa_mask.sum().clamp(min=1)
sequences[bi].append(decode_tokens(toks_pad))
spectral_lps[bi].append(float(sp_lp_i.cpu()))
                        gate_confs[bi].append(1.0 if gate_conf_b is None else float(gate_conf_b[bi].mean().cpu()))
elif T_sample != 1.0 and n_candidates > 1:
# Temperature sampling: diverse candidates for reranking
scaled = logits / max(T_sample, 1e-6)
probs = F.softmax(scaled, dim=-1)
x0_samp = torch.multinomial(
probs.reshape(-1, VOCAB_SIZE), 1).reshape(N, SEQ_LEN)
sp_lp = log_fin.gather(-1, x0_samp.unsqueeze(-1)).squeeze(-1)
aa_mask = (x0_samp >= 3).float()
sp_lp = (sp_lp * aa_mask).sum(-1) / aa_mask.sum(-1).clamp(min=1)
for i, seq in enumerate(x0_samp.cpu().numpy()):
sequences[i].append(decode_tokens(seq))
spectral_lps[i].append(float(sp_lp[i].cpu()))
                        gate_confs[i].append(1.0 if gate_conf_b is None else float(gate_conf_b[i].mean().cpu()))
elif use_cfid:
# NOVEL #11: Mask-Predict iterative decoding (n_iter=8 for better convergence)
cfid_seqs = mask_predict_decode(logits, context, denoiser, n_iter=8,
peaks=spec_t, precursor_masses=mass_t,
memory_key_padding_mask=peak_pad_mask,
                                                    self_cond=self_cond)
for bi, seq in enumerate(cfid_seqs):
tok_ids = [(CHAR_TO_IDX.get(c, 0) if c in CHAR_TO_IDX else 0)
for c in seq]
tok_t = torch.tensor((tok_ids + [0] * SEQ_LEN)[:SEQ_LEN],
dtype=torch.long, device=device)
lp = log_fin[bi].gather(-1, tok_t.unsqueeze(-1)).squeeze(-1)
aa_mask = (tok_t >= 3).float()
sp_lp_i = (lp * aa_mask).sum() / aa_mask.sum().clamp(min=1)
sequences[bi].append(seq)
spectral_lps[bi].append(float(sp_lp_i.cpu()))
                        gate_confs[bi].append(1.0 if gate_conf_b is None else float(gate_conf_b[bi].mean().cpu()))
else:
# Default: argmax
x0_final = logits.argmax(-1) # (N, L)
sp_lp = log_fin.gather(-1, x0_final.unsqueeze(-1)).squeeze(-1)
aa_mask = (x0_final >= 3).float()
sp_lp = (sp_lp * aa_mask).sum(-1) / aa_mask.sum(-1).clamp(min=1)
for i, seq in enumerate(x0_final.cpu().numpy()):
sequences[i].append(decode_tokens(seq))
spectral_lps[i].append(float(sp_lp[i].cpu()))
                        gate_confs[i].append(1.0 if gate_conf_b is None else float(gate_conf_b[i].mean().cpu()))
return sequences, spectral_lps, gate_confs, t0_logits
# ── Save Predictions CSV (deliverable) ────────────────────────────────────────
def save_predictions(sequences, spectral_lps, gate_confs,
out_path='results/diffusion_predictions.csv'):
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
rows = []
for i, (seqs, lps, gcs) in enumerate(zip(sequences, spectral_lps, gate_confs)):
for seq, lp, gc in zip(seqs, lps, gcs):
rows.append({'spectrum_id': i, 'sequence': seq,
'spectral_logprob': lp, 'gate_confidence': gc})
df = pd.DataFrame(rows)
df.to_csv(out_path, index=False)
print(f"Saved {len(df)} predictions β†’ {out_path}")
return df
# ── Evaluation ─────────────────────────────────────────────────────────────────
def aa_recall(pred: str, true: str) -> float:
"""Positional recall: fraction of positions where pred[i] == true[i]."""
matches = sum(a == b for a, b in zip(pred, true))
return matches / max(len(true), 1)
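# Examples (illustrative comments only): aa_recall("AGS", "AGT") = 2/3 and
# aa_recall("AG", "AGT") = 2/3, since zip() truncates to the shorter string while
# the denominator is always len(true); the metric therefore also penalises
# predictions that are too short.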
def evaluate_aa_recall(encoder, denoiser, X_test, y_test, masses_test,
batch_size=32, results_dir='results', device=None,
use_gate=False, use_beam=False, use_cfid=False,
use_rerank=False, n_rerank=30, T_sample=0.8,
use_sgir=False, raw_peaks=None, use_esm=False,
use_mass_correct=False):
"""
Evaluate de-novo sequencing on a held-out test set.
    Flags (can be combined):
      use_gate         – NOVEL #1: entropy-adaptive mass gate on the final-step logits
      use_beam         – NOVEL #8: mass-constrained beam search at final step
      use_cfid         – NOVEL #11: Mask-Predict iterative decoding (bidirectional)
      use_rerank       – NOVEL #10: generate n_rerank candidates, ESM-2+spectral rerank
      use_sgir         – NOVEL #9: spectrally-grounded iterative refinement post-decode
                         requires raw_peaks: list[(mz_arr, int_arr)] aligned with X_test
      use_mass_correct – Option B: post-hoc single-swap mass correction
"""
if device is None:
device = next(encoder.parameters()).device
n_cands = n_rerank if use_rerank else 1
t_samp = T_sample if use_rerank else 1.0
all_seqs, all_lps, all_gcs = [], [], []
recalls, pep_correct = [], []
for i in range(0, len(X_test), batch_size):
bs = X_test[i:i+batch_size]
bm = masses_test[i:i+batch_size]
byt = y_test[i:i+batch_size]
brp = raw_peaks[i:i+batch_size] if raw_peaks is not None else None
seqs, lps, gcs, t0_lgs = generate_sequences(
encoder, denoiser, bs, bm,
n_candidates=n_cands, T_sample=t_samp,
device=device, use_gate=use_gate, use_beam=use_beam,
use_cfid=use_cfid)
for j, (pred_list, lp_list, gc_list, true_tok) in enumerate(
zip(seqs, lps, gcs, byt)):
pm = float(bm[j])
obs_mz = brp[j][0] if brp is not None else np.array([])
obs_int = brp[j][1] if brp is not None else np.array([])
# NOVEL #10: rerank candidates
if use_rerank and len(pred_list) > 1:
best = rank_candidates(pred_list, lp_list, obs_mz, obs_int, pm,
use_esm=use_esm)
best_lp = lp_list[pred_list.index(best)] if best in pred_list else lp_list[0]
else:
best = pred_list[0]
best_lp = lp_list[0]
# NOVEL #9: SGIR refinement
if use_sgir and brp is not None and t0_lgs[j] is not None:
best = spectrally_grounded_refine(
best, obs_mz, obs_int, pm, t0_lgs[j])
# Option B: post-hoc single-swap mass correction
if use_mass_correct:
best = mass_correct_sequence(best, pm)
all_seqs.append([best]); all_lps.append([best_lp])
all_gcs.append([gc_list[0]])
true_seq = decode_tokens(true_tok)
recalls.append(aa_recall(best, true_seq))
pep_correct.append(best == true_seq)
aa_rec = float(np.mean(recalls)) * 100
pep_acc = float(np.mean(pep_correct)) * 100
print(f"AA Recall : {aa_rec:.2f}%")
print(f"Peptide Acc: {pep_acc:.2f}%")
os.makedirs(results_dir, exist_ok=True)
save_predictions(all_seqs, all_lps, all_gcs,
os.path.join(results_dir, 'diffusion_predictions.csv'))
return aa_rec, pep_acc
# ── Checkpoint I/O ─────────────────────────────────────────────────────────────
def load_checkpoint(path: str, device=None):
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt = torch.load(path, map_location=device, weights_only=False)
encoder = PeakEncoder().to(device)
denoiser = TransformerDenoiser().to(device)
encoder.load_state_dict(ckpt['encoder'])
denoiser.load_state_dict(ckpt['denoiser'])
print(f"Loaded checkpoint from epoch {ckpt.get('epoch', '?')}")
return encoder, denoiser
# ── Main ───────────────────────────────────────────────────────────────────────
if __name__ == '__main__':
BASE = os.path.join(os.path.dirname(__file__), '..', 'Data', 'E coli EV proteomics')
mzml_paths = sorted(glob.glob(os.path.join(BASE, '*.mzML')))
xlsx_paths = sorted(glob.glob(os.path.join(BASE, 'Database search output*.xlsx')))
if not mzml_paths:
raise FileNotFoundError(f"No mzML files in {BASE}")
encoder, denoiser, (X_te, y_te, m_te) = train_diffusion(
mzml_paths, xlsx_paths, checkpoint_dir='checkpoints', epochs=50)
evaluate_aa_recall(encoder, denoiser, X_te, y_te, m_te)