AlienChen/Storage / pCoMole /model /base_models.py
AlienChen's picture
download
raw
21.2 kB
from typing import Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import EsmModel
import pytorch_lightning as pl
from .utils import build_z0_z1_with_alignment, remove_eps
import pdb
# ---------- Utilities ----------
def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return not (x is None)
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``."""
    if val is None:
        return d
    return val
# ---------- Timestep embedding (sinusoidal -> MLP) ----------
class TimeEmbedding(nn.Module):
    """
    Sinusoidal time embedding followed by a small MLP.

    Accepts t of shape (B,) or a scalar and returns (B, d_model); callers
    broadcast the result over the sequence length themselves.

    Args:
        d_model: output width. Must be >= 2 so at least one sin/cos pair exists.
        hidden: width of the MLP hidden layer; defaults to 4 * d_model.
        max_period: ratio between the highest and lowest sinusoid frequency.
    """

    def __init__(self, d_model: int, hidden: Optional[int] = None, max_period: int = 10000):
        super().__init__()
        if d_model < 2:
            raise ValueError("d_model must be >= 2 to hold a sin/cos pair.")
        self.d_model = d_model
        self.max_period = max_period
        hidden = d_model * 4 if hidden is None else hidden
        # The sinusoidal features come in sin/cos pairs, so their width is the
        # largest even number <= d_model. The MLP maps that width to d_model,
        # which is what makes odd d_model work without any padding.
        pe_dim = d_model if d_model % 2 == 0 else d_model - 1
        self.pe_dim = pe_dim
        self.mlp = nn.Sequential(
            nn.Linear(pe_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, t: torch.Tensor, batch_size: Optional[int] = None) -> torch.Tensor:
        """
        t: (B,) or () in [0,1]
        batch_size: required only when t is a scalar tensor
        returns: (B, d_model)

        Raises:
            ValueError: if t is a scalar and batch_size is not given.
        """
        if t.dim() == 0:
            # scalar -> expand to batch
            if batch_size is None:
                raise ValueError("When t is scalar, provide batch_size.")
            t = t.expand(batch_size)

        half = self.pe_dim // 2
        # Geometric frequency ladder from 1 down to 1/max_period. The 1e-8
        # keeps the denominator non-zero when half == 1.
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=t.dtype)
            * (-math.log(self.max_period) / (half - 1 + 1e-8))
        )
        angles = t[:, None] * freqs[None, :] * math.pi  # (B, half)
        pe = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (B, pe_dim)
        # No padding here: the first Linear layer takes pe_dim inputs, so
        # padding pe up to d_model (as was done before) broke odd d_model.
        return self.mlp(pe)  # (B, d_model)
# ---------- RoPE (rotary position embedding) ----------
def apply_rotary(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate queries and keys with rotary position embeddings (half-split layout).

    The last dimension is cut into a first and a second half; element i of the
    first half is paired with element i of the second half and the pair is
    rotated using cos[..., i] / sin[..., i].

    q, k: (B, h, L, d_head)
    cos, sin: (L, half) where half is half of the (padded) last dim of q;
              leading singleton dims are prepended until the rank matches q.

    An odd last dim is zero-padded by one element before rotating, and the
    returned tensors keep that padded width.
    """
    if q.shape[-1] % 2 != 0:
        # pad to even so the halves line up
        q = F.pad(q, (0, 1), value=0.0)
        k = F.pad(k, (0, 1), value=0.0)
    half = q.shape[-1] // 2

    q_lo, q_hi = q[..., :half], q[..., half:]
    k_lo, k_hi = k[..., :half], k[..., half:]

    # prepend singleton dims so cos/sin broadcast against (B, h, L, half)
    for _ in range(q_lo.dim() - cos.dim()):
        cos = cos.unsqueeze(0)
        sin = sin.unsqueeze(0)

    rotated_q = torch.cat([q_lo * cos - q_hi * sin, q_lo * sin + q_hi * cos], dim=-1)
    rotated_k = torch.cat([k_lo * cos - k_hi * sin, k_lo * sin + k_hi * cos], dim=-1)
    # q and k were re-bound to their padded versions above, so these slices
    # keep the padded width.
    return rotated_q[..., :q.shape[-1]], rotated_k[..., :k.shape[-1]]
class RotaryPositionalEmbedding(nn.Module):
    """
    Lazily builds and caches the RoPE cos/sin tables for a given head_dim.

    The tables have shape (L, ceil(head_dim / 2)): one angle per rotated pair.
    `max_len` is only stored; the cache grows on demand in `_update_cache`.
    """

    def __init__(self, head_dim: int, max_len: int = 8192, base: int = 10000):
        super().__init__()
        # Odd head_dim is tolerated here (apply_rotary pads), even is preferred.
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)
        self.max_len = max_len
        self.head_dim = head_dim
        # Number of positions currently held by the cached tables.
        self._cached_len = 0
        for table_name in ("cos_cached", "sin_cached"):
            self.register_buffer(table_name, torch.empty(0), persistent=False)

    def _update_cache(self, seq_len: int, device, dtype):
        """Rebuild the tables unless the cache already covers seq_len on this device/dtype."""
        cache_is_valid = (
            seq_len <= self._cached_len
            and self.cos_cached.device == device
            and self.cos_cached.dtype == dtype
        )
        if cache_is_valid:
            return
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        # outer product position x inverse frequency -> (L, head_dim/2)
        angles = positions[:, None] * self.inv_freq[None, :]
        self.cos_cached = angles.cos().to(dtype=dtype)
        self.sin_cached = angles.sin().to(dtype=dtype)
        self._cached_len = seq_len

    def forward(self, L: int, device, dtype):
        """Return the (cos, sin) tables truncated to the first L positions."""
        self._update_cache(L, device, dtype)
        cos_table = self.cos_cached[:L]
        sin_table = self.sin_cached[:L]
        return cos_table, sin_table
# ---------- Rotary MHA + Transformer block ----------
class RotaryMHA(nn.Module):
    """
    Multi-head self-attention whose queries and keys are rotated with RoPE.

    `rope` is a callable taking (L, device, dtype) and returning the cos/sin
    tables that are handed to `apply_rotary`.
    """

    def __init__(self, d_model: int, n_heads: int, rope: RotaryPositionalEmbedding, attn_dropout: float = 0.0, proj_dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        # fused projection producing q, k and v in one matmul
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj_dropout = nn.Dropout(proj_dropout)
        self.rope = rope

    def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
        """
        x: (B, L, d_model)
        key_padding_mask: (B, L) bool, True=pad (masked)
        returns: (B, L, d_model)
        """
        B, L, D = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, L, D) -> (B, h, L, d)
            return t.view(B, L, self.n_heads, self.d_head).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.qkv(x).chunk(3, dim=-1))

        cos, sin = self.rope(L, x.device, x.dtype)
        q, k = apply_rotary(q, k, cos, sin)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B, h, L, L)
        if key_padding_mask is not None:
            # True marks a padded key: it must receive zero attention weight
            pad_keys = key_padding_mask.to(dtype=torch.bool)[:, None, None, :]  # (B,1,1,L)
            scores = scores.masked_fill(pad_keys, float('-inf'))

        weights = self.attn_dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)  # (B, h, L, d)
        merged = context.transpose(1, 2).contiguous().view(B, L, D)
        return self.proj_dropout(self.proj(merged))
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: rotary self-attention followed by a SiLU MLP,
    each wrapped in a residual connection.

    Args:
        d_model: model width.
        n_heads: number of attention heads (must divide d_model).
        mlp_ratio: hidden width of the MLP as a multiple of d_model.
        attn_dropout: dropout applied to the attention weights.
        proj_dropout: dropout applied after the attention output projection.
        rope: rotary embedding shared across blocks. When omitted, the block
            creates its own instance for head_dim = d_model // n_heads.
    """

    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, attn_dropout: float = 0.0, proj_dropout: float = 0.0, rope: Optional[RotaryPositionalEmbedding] = None):
        super().__init__()
        if rope is None:
            # RotaryMHA calls rope(...) unconditionally, so the None default
            # must be turned into a real embedding to be usable.
            rope = RotaryPositionalEmbedding(head_dim=d_model // n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = RotaryMHA(d_model, n_heads, rope=rope, attn_dropout=attn_dropout, proj_dropout=proj_dropout)
        self.norm2 = nn.LayerNorm(d_model)
        hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
        """
        x: (B, L, d_model)
        key_padding_mask: (B, L) bool, True=pad; forwarded to the attention layer
        returns: (B, L, d_model)
        """
        x = x + self.attn(self.norm1(x), key_padding_mask=key_padding_mask)
        x = x + self.mlp(self.norm2(x))
        return x
class ProteinEditFlowModel(nn.Module):
    """
    Edit-flow rate/logit predictor on top of a frozen ESM encoder.

    Inputs:
        x_t: (B, L) Long token ids
        mask: (B, L) bool, True = valid token, False = pad (ignored)
        t: (B,) or scalar in [0,1]
    Outputs:
        lam_ins: (B, L) >= 0
        logits_ins: (B, L, V)
        lam_del: (B, L) >= 0
        lam_sub: (B, L) >= 0
        logits_sub: (B, L, V)

    Config attributes (all optional, read with getattr): d_model, n_layers,
    n_heads, mlp_ratio, max_len, dropout, attn_dropout, proj_dropout and
    esm_model_name (Hugging Face id of the ESM checkpoint).
    """

    DEFAULT_ESM_MODEL = "facebook/esm2_t33_650M_UR50D"

    def __init__(self, vocab_size, pad_id, config):
        super().__init__()
        self.d_model = getattr(config, "d_model", 768)
        self.n_layers = getattr(config, "n_layers", 12)
        self.n_heads = getattr(config, "n_heads", 12)
        self.mlp_ratio = getattr(config, "mlp_ratio", 4)
        self.max_len = getattr(config, "max_len", 2048)
        self.dropout = getattr(config, "dropout", 0.1)
        self.attn_dropout = getattr(config, "attn_dropout", 0)
        self.proj_dropout = getattr(config, "proj_dropout", 0)
        self.vocab_size = vocab_size
        self.pad_id = pad_id

        # --- Embedding ---
        # `or` also covers config objects that return None for missing keys.
        esm_name = getattr(config, "esm_model_name", None) or self.DEFAULT_ESM_MODEL
        self.esm_emb = EsmModel.from_pretrained(esm_name)
        # The encoder is used as a frozen feature extractor.
        for param in self.esm_emb.parameters():
            param.requires_grad = False
        self.time_emb = TimeEmbedding(d_model=self.d_model)
        # Project from the encoder width (1280 for the default checkpoint).
        self.tok_embed_to_hidden = nn.Linear(self.esm_emb.config.hidden_size, self.d_model)

        # --- RoPE shared by attention blocks ---
        rope = RotaryPositionalEmbedding(head_dim=self.d_model // self.n_heads, max_len=self.max_len)

        # --- Encoder ---
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                mlp_ratio=self.mlp_ratio,
                attn_dropout=self.attn_dropout,
                proj_dropout=self.proj_dropout,
                rope=rope,
            )
            for _ in range(self.n_layers)
        ])
        self.final_norm = nn.LayerNorm(self.d_model)

        # --- Heads ---
        # We use small MLP heads for rates; logits are linear.
        self.lam_ins_head = nn.Sequential(nn.Linear(self.d_model, self.d_model // 2), nn.SiLU(), nn.Linear(self.d_model // 2, 1))
        self.lam_del_head = nn.Sequential(nn.Linear(self.d_model, self.d_model // 2), nn.SiLU(), nn.Linear(self.d_model // 2, 1))
        self.lam_sub_head = nn.Sequential(nn.Linear(self.d_model, self.d_model // 2), nn.SiLU(), nn.Linear(self.d_model // 2, 1))
        self.logits_ins_head = nn.Linear(self.d_model, vocab_size, bias=False)
        self.logits_sub_head = nn.Linear(self.d_model, vocab_size, bias=False)
        # nonnegativity via softplus (safer than exp)
        self.softplus = nn.Softplus(beta=1.0)

    def forward(
        self,
        x_t: torch.LongTensor,
        mask: torch.BoolTensor,
        t: torch.Tensor,
    ):
        """
        x_t: (B, L) long tokens
        mask: (B, L) bool, True = valid token, False = PAD (ignored);
              may be None, in which case nothing is masked
        t: (B,) or scalar in [0,1]

        Returns (lam_ins, logits_ins, lam_del, lam_sub, logits_sub).
        """
        B = x_t.shape[0]

        # --- Embedding ---
        # NOTE(review): x_t is fed directly to the ESM encoder, so the ids
        # presumably follow the ESM tokenizer's vocabulary; verify at the caller.
        h = self.esm_emb(input_ids=x_t, attention_mask=mask).last_hidden_state
        h = self.tok_embed_to_hidden(h)

        # --- Add time embedding (broadcast across length) ---
        t_emb = self.time_emb(t, batch_size=B)  # (B, d_model)
        h = h + t_emb.unsqueeze(1)  # (B, L, d_model)

        # --- Encoder blocks with key padding mask ---
        # The attention layer expects True = pad, the inverse of `mask`.
        key_padding_mask = ~mask if exists(mask) else None
        for blk in self.blocks:
            h = blk(h, key_padding_mask=key_padding_mask)
        h = self.final_norm(h)  # (B, L, d_model)

        # --- Heads ---
        lam_ins = self.softplus(self.lam_ins_head(h)).squeeze(-1)  # (B, L)
        lam_del = self.softplus(self.lam_del_head(h)).squeeze(-1)  # (B, L)
        lam_sub = self.softplus(self.lam_sub_head(h)).squeeze(-1)  # (B, L)
        logits_ins = self.logits_ins_head(h)  # (B, L, V)
        logits_sub = self.logits_sub_head(h)  # (B, L, V)

        # --- Zero-out padded positions so they contribute nothing downstream ---
        if exists(mask):
            # For lambdas: force to 0 on pads
            valid_f = mask.to(h.dtype)  # True=valid -> 1.0
            lam_ins = lam_ins * valid_f
            lam_del = lam_del * valid_f
            lam_sub = lam_sub * valid_f
            # kill logits on pads (large negative, but finite)
            pad_positions = key_padding_mask.unsqueeze(-1)
            logits_ins = logits_ins.masked_fill(pad_positions, -1e4)
            logits_sub = logits_sub.masked_fill(pad_positions, -1e4)

        return lam_ins, logits_ins, lam_del, lam_sub, logits_sub
class SMILESEditFlowModel(nn.Module):
    """
    Edit-flow rate/logit predictor with a learned token embedding.

    Inputs:
        x_t: (B, L) Long token ids
        mask: (B, L) bool, True = valid token, False = pad (ignored)
        t: (B,) or scalar in [0,1]
    Outputs:
        lam_ins: (B, L) >= 0
        logits_ins: (B, L, V)
        lam_del: (B, L) >= 0
        lam_sub: (B, L) >= 0
        logits_sub: (B, L, V)

    Config attributes (all optional, read with getattr): d_model, n_layers,
    n_heads, mlp_ratio, max_len, dropout, attn_dropout, proj_dropout.
    """

    def __init__(self, vocab_size, pad_id, config):
        super().__init__()
        self.d_model = getattr(config, "d_model", 768)
        self.n_layers = getattr(config, "n_layers", 12)
        self.n_heads = getattr(config, "n_heads", 12)
        self.mlp_ratio = getattr(config, "mlp_ratio", 4)
        self.max_len = getattr(config, "max_len", 2048)
        self.dropout = getattr(config, "dropout", 0.1)
        self.attn_dropout = getattr(config, "attn_dropout", 0)
        self.proj_dropout = getattr(config, "proj_dropout", 0)
        self.vocab_size = vocab_size
        self.pad_id = pad_id

        # --- Embedding ---
        # padding_idx keeps the pad token's embedding row at zero.
        self.seq_emb = nn.Embedding(self.vocab_size, self.d_model, padding_idx=self.pad_id)
        self.time_emb = TimeEmbedding(d_model=self.d_model)

        # --- RoPE shared by attention blocks ---
        rope = RotaryPositionalEmbedding(head_dim=self.d_model // self.n_heads, max_len=self.max_len)

        # --- Encoder ---
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                mlp_ratio=self.mlp_ratio,
                attn_dropout=self.attn_dropout,
                proj_dropout=self.proj_dropout,
                rope=rope,
            )
            for _ in range(self.n_layers)
        ])
        self.final_norm = nn.LayerNorm(self.d_model)

        # --- Heads ---
        # We use small MLP heads for rates; logits are linear.
        self.lam_ins_head = nn.Sequential(nn.Linear(self.d_model, self.d_model // 2), nn.SiLU(), nn.Linear(self.d_model // 2, 1))
        self.lam_del_head = nn.Sequential(nn.Linear(self.d_model, self.d_model // 2), nn.SiLU(), nn.Linear(self.d_model // 2, 1))
        self.lam_sub_head = nn.Sequential(nn.Linear(self.d_model, self.d_model // 2), nn.SiLU(), nn.Linear(self.d_model // 2, 1))
        self.logits_ins_head = nn.Linear(self.d_model, vocab_size, bias=False)
        self.logits_sub_head = nn.Linear(self.d_model, vocab_size, bias=False)
        # nonnegativity via softplus (safer than exp)
        self.softplus = nn.Softplus(beta=1.0)

    def forward(
        self,
        x_t: torch.LongTensor,
        mask: torch.BoolTensor,
        t: torch.Tensor,
    ):
        """
        x_t: (B, L) long tokens
        mask: (B, L) bool, True = valid token, False = PAD (ignored);
              may be None, in which case nothing is masked
        t: (B,) or scalar in [0,1]

        Returns (lam_ins, logits_ins, lam_del, lam_sub, logits_sub).
        """
        B = x_t.shape[0]

        # --- Embedding ---
        h = self.seq_emb(x_t)

        # --- Add time embedding (broadcast across length) ---
        t_emb = self.time_emb(t, batch_size=B)  # (B, d_model)
        h = h + t_emb.unsqueeze(1)  # (B, L, d_model)

        # --- Encoder blocks with key padding mask ---
        # The attention layer expects True = pad, the inverse of `mask`.
        key_padding_mask = ~mask if exists(mask) else None
        for blk in self.blocks:
            h = blk(h, key_padding_mask=key_padding_mask)
        h = self.final_norm(h)  # (B, L, d_model)

        # --- Heads ---
        lam_ins = self.softplus(self.lam_ins_head(h)).squeeze(-1)  # (B, L)
        lam_del = self.softplus(self.lam_del_head(h)).squeeze(-1)  # (B, L)
        lam_sub = self.softplus(self.lam_sub_head(h)).squeeze(-1)  # (B, L)
        logits_ins = self.logits_ins_head(h)  # (B, L, V)
        logits_sub = self.logits_sub_head(h)  # (B, L, V)

        # --- Zero-out padded positions so they contribute nothing downstream ---
        if exists(mask):
            # For lambdas: force to 0 on pads
            valid_f = mask.to(h.dtype)  # True=valid -> 1.0
            lam_ins = lam_ins * valid_f
            lam_del = lam_del * valid_f
            lam_sub = lam_sub * valid_f
            # kill logits on pads (large negative, but finite)
            pad_positions = key_padding_mask.unsqueeze(-1)
            logits_ins = logits_ins.masked_fill(pad_positions, -1e4)
            logits_sub = logits_sub.masked_fill(pad_positions, -1e4)

        return lam_ins, logits_ins, lam_del, lam_sub, logits_sub
class EditFlow(pl.LightningModule):
    """
    Lightning wrapper that trains an edit-flow model.

    It samples a source sequence and a time, builds aligned sequences, samples
    an intermediate state from `path`, runs `model` on it and scores the
    predicted rates/logits with `loss_fn`.

    Args:
        model: network called as model(x_t=..., mask=..., t=...) and returning
            (lam_ins, logits_ins, lam_del, lam_sub, logits_sub).
        loss_fn: callable producing the scalar loss (see `_shared_step`).
        path: object providing `scheduler(t)`, `sample(z_0, z_1, t=...)` and,
            optionally, `eps_id`.
        source_distribution: object providing `_allowed_tokens` and
            `sample_x0_from_x1(...)`.
        pad_id, bos_id, eos_id: special token ids.
        config: must expose `optim.{lr,beta1,beta2,eps,weight_decay,fused}` and
            `model.{scale_size,p_optimal}`; `optim.warmup_ratio` is optional.
    """

    def __init__(self,
                 model,
                 loss_fn,
                 path,
                 source_distribution,
                 pad_id,
                 bos_id,
                 eos_id,
                 config,
                 ):
        super().__init__()
        self.cfg = config
        self.source_distribution = source_distribution
        self.path = path
        self.model = model
        self.loss_fn = loss_fn
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.pad_id = pad_id
        # id of the gap ("epsilon") token; -1 when the path does not define one
        self.eps_id = getattr(self.path, "eps_id", -1)
        # filled in by on_train_start; the LR schedule is flat until then
        self._total_steps = None

    def configure_optimizers(self):
        """Build AdamW plus a per-step linear-warmup / cosine-decay schedule."""
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=float(self.cfg.optim.lr),
            betas=(self.cfg.optim.beta1, self.cfg.optim.beta2),
            eps=float(self.cfg.optim.eps),
            weight_decay=self.cfg.optim.weight_decay,
            fused=self.cfg.optim.fused,
        )
        warmup_ratio = getattr(self.cfg.optim, "warmup_ratio", 0.1)
        min_scale = 0.1

        def lr_lambda(global_step: int):
            # until on_train_start runs we just return 1.0
            # NOTE(review): the scheduler presumably evaluates this once at
            # construction, before _total_steps is known, so the first
            # optimizer step may run at the full base LR rather than the
            # warmup value; confirm whether that is intended.
            if self._total_steps is None or self._total_steps == 0:
                return 1.0
            total_steps = self._total_steps
            warmup_steps = max(1, int(warmup_ratio * total_steps))
            if global_step < warmup_steps:
                # linear warmup: 0.1 -> 1.0
                alpha = (global_step + 1) / warmup_steps
                return 0.1 + 0.9 * alpha
            # cosine from 1.0 down to min_scale
            progress = (global_step - warmup_steps) / max(1, total_steps - warmup_steps)
            cosine = 0.5 * (1 + math.cos(math.pi * progress))  # 1 -> 0
            return min_scale + (1.0 - min_scale) * cosine

        sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "interval": "step",  # <- per-step
                "frequency": 1,
            },
        }

    def preparation(self, x_1):
        """
        Sample a training state for target batch x_1 and run the model on it.

        Returns (lam_ins, logits_ins, lam_del, lam_sub, logits_sub,
        z_t, z_1, x_t, mask, weight).
        """
        B = x_1.shape[0]
        # NOTE(review): the source lost its indentation; the no_grad block is
        # reconstructed to cover the sampling/alignment steps only. The model
        # call must stay outside it so gradients reach the parameters.
        with torch.no_grad():
            # the gap token is never a valid source token
            allowed_tokens = torch.tensor(
                [tok for tok in self.source_distribution._allowed_tokens if tok != self.eps_id],
                device=self.device,
            )
            x_0 = self.source_distribution.sample_x0_from_x1(
                x_1,
                pad_id=self.pad_id,
                allowed_tokens=allowed_tokens,
                scale_size=self.cfg.model.scale_size,
                bos_id=self.bos_id,
                eos_id=self.eos_id,
            )
            t = torch.rand(B, device=self.device)
            sched = self.path.scheduler(t)
            weight = sched.d_alpha_t / sched.sigma_t  # presumably (B,); confirm against path.scheduler
            z_0, z_1 = build_z0_z1_with_alignment(
                x_0, x_1, self.eps_id, self.pad_id, self.bos_id, self.eos_id,
                p_optimal=self.cfg.model.p_optimal,
            )
            z_t = self.path.sample(z_0, z_1, t=t)
            x_t, mask = remove_eps(z_t, self.eps_id, self.pad_id)
        lam_ins, logits_ins, lam_del, lam_sub, logits_sub = self.model(x_t=x_t, mask=mask, t=t)
        return lam_ins, logits_ins, lam_del, lam_sub, logits_sub, z_t, z_1, x_t, mask, weight

    def _shared_step(self, batch, stage):
        """Compute and log the loss for one batch; `stage` prefixes the metric name."""
        # as_tensor avoids the copy (and warning) torch.tensor incurs when the
        # dataloader already yields a tensor, and still accepts nested lists.
        x_1 = torch.as_tensor(batch["input_ids"], device=self.device)
        B = x_1.shape[0]
        lam_ins, logits_ins, lam_del, lam_sub, logits_sub, z_t, z_1, x_t, mask, weight = self.preparation(x_1)
        loss = self.loss_fn(lam_ins, logits_ins, lam_del, lam_sub, logits_sub,
                            z_t, z_1, x_t, mask, weight, self.eps_id, self.bos_id, self.eos_id)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=B, sync_dist=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for `batch` (logged as train_loss)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for `batch` (logged as val_loss)."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Return the test loss for `batch` (logged as test_loss)."""
        return self._shared_step(batch, "test")

    def on_train_start(self):
        # how many optimizer steps we will take in this fit
        self._total_steps = self.trainer.estimated_stepping_batches

Xet Storage Details

Size:
21.2 kB
·
Xet hash:
e7c58604c9b32d720870b71e8e85b383460fb2df153375d2a6e6674b6a5b611f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.