AlienChen/Storage / pCoMole /model /transformer.py
AlienChen's picture
download
raw
17.9 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
# Part of this implementation is adapted from https://github.com/facebookresearch/DiT
# which is released under NonCommercial-4.0 license
# Part of this implementation is adapted from https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
# which is released under MIT license
# Part of this implementation is adapted from https://github.com/louaaron/Score-Entropy-Discrete-Diffusion
# which is released under MIT license
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention
from einops import rearrange
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torch import nn, Tensor
from transformers import EsmModel
from . import rotary
import sys
sys.path.append('/scratch/pranamlab/tong/cope/editflows/flow_matching')
import pdb
def lengths_to_offsets(lengths: torch.Tensor) -> torch.Tensor:
    """Convert per-sequence lengths into start/end offsets of a packed layout.

    Args:
        lengths: 1-D tensor with one length per sequence, shape ``(B,)``.

    Returns:
        Tensor of shape ``(B + 1,)``: a leading zero followed by the running
        sum of ``lengths``, so sequence ``i`` spans
        ``offsets[i]:offsets[i + 1]``.
    """
    running_total = lengths.cumsum(0)
    # Left-pad by one element (default fill value 0) for the first offset.
    one_left_none_right = (1, 0)
    return F.pad(running_total, one_left_none_right)
def build_seq_ids(lengths: torch.Tensor, device=None) -> torch.Tensor:
    """Return, for every packed token, the index of the sequence it belongs to.

    Args:
        lengths: 1-D integer tensor of per-sequence lengths, shape ``(B,)``.
        device: Device for the result. ``None`` (the default) means "use the
            device of ``lengths``".

    Returns:
        Tensor of shape ``(T,)`` with ``T = sum(lengths)``; sequence index
        ``i`` is repeated ``lengths[i]`` times.
    """
    # Compare against None explicitly: a plain truthiness test would treat a
    # valid but falsy device spec (e.g. the CUDA index 0) as "not given".
    if device is None:
        device = lengths.device
    # repeat_interleave needs the repeat counts on the same device as the
    # values being repeated.
    lengths = lengths.to(device)
    seq_index = torch.arange(lengths.numel(), device=device)
    return torch.repeat_interleave(seq_index, lengths)
def make_score_mod_for_intra_sequence_only(lengths: torch.Tensor):
    """Build a FlexAttention ``score_mod`` that blocks cross-sequence attention.

    Args:
        lengths: 1-D tensor of per-sequence lengths for the packed batch.

    Returns:
        A callable ``score_mod(scores, b, h, q_idx, k_idx)`` that keeps the
        score when query and key tokens belong to the same sequence and
        otherwise replaces it with the most negative finite value of the
        score dtype.
    """
    token_to_seq = build_seq_ids(lengths, device=lengths.device)  # (T,)

    def score_mod(scores, b, h, q_idx, k_idx):
        # Finite minimum of the dtype is used instead of -inf, and the
        # result is built out-of-place with torch.where.
        fill_value = torch.finfo(scores.dtype).min
        blocked = torch.full_like(scores, fill_value)
        in_same_sequence = token_to_seq[q_idx] == token_to_seq[k_idx]
        return torch.where(in_same_sequence, scores, blocked)

    return score_mod
# def make_score_mod_for_intra_sequence_only(lengths: torch.Tensor):
# """
# Returns a score_mod callback for FlexAttention that sets logits to -inf
# when query and key belong to different sequences (ragged, no padding).
# """
# seq_ids = build_seq_ids(lengths, device=lengths.device) # (T,)
# def score_mod(scores, b, h, q_idx, k_idx):
# # scores: (..., q_block, k_block); q_idx/k_idx: flat token indices in [0..T-1]
# original_type = scores.dtype
# scores = scores.float()
# same = (seq_ids[q_idx] == seq_ids[k_idx]).to(scores.dtype)
# scores += (same - 1) * 1e9 # subtract a large number for cross-seq pairs
# return scores.to(original_type)
# return score_mod
def bias_dropout_add_scale(
    x: Tensor, scale: Tensor, residual: Optional[Tensor], prob: float, training: bool
) -> Tensor:
    """Apply dropout to ``x``, scale it, and add the optional residual.

    Args:
        x: Branch output to be regularised.
        scale: Multiplicative gate, broadcastable against ``x``.
        residual: Skip-connection tensor, or ``None`` to skip the addition.
        prob: Dropout probability.
        training: Whether dropout is active.

    Returns:
        ``residual + scale * dropout(x)``, or just ``scale * dropout(x)``
        when ``residual`` is ``None``.
    """
    gated = scale * F.dropout(x, p=prob, training=training)
    # The signature advertises Optional[Tensor]; honour it instead of failing
    # on ``None + Tensor``.
    if residual is None:
        return gated
    return residual + gated
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Apply an affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` must be broadcastable against ``x``.
    """
    gain = 1 + scale
    rescaled = x * gain
    return rescaled + shift
class LayerNorm(nn.Module):
    """Bias-free layer norm over the last axis, computed in float32.

    Accepts 2-D ``(T, H)`` or 3-D ``(B, S, H)`` inputs; any other rank raises
    ``ValueError``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        # Normalise outside of CUDA autocast, on a float32 copy of the input.
        with torch.amp.autocast("cuda", enabled=False):
            normed = F.layer_norm(x.float(), [self.dim])
        if normed.dim() not in (2, 3):
            raise ValueError(f"LayerNorm expects 2D/3D, got {normed.shape}")
        # The (H,) weight broadcasts over the leading axes of both layouts.
        return normed * self.weight
class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by an MLP.
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(time: Tensor, dim: int, max_period: int = 10000) -> Tensor:
        """
        Create sinusoidal timestep embeddings.
        :param time: a 1-D Tensor of N values, one per batch element.
                     These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor laid out as [cos | sin], plus one trailing
                 zero column when ``dim`` is odd.
        """
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        # Geometric frequency ladder from 1 down towards 1 / max_period.
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=time.device)
        angles = time[:, None].float() * freqs[None]
        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            # Odd width: append a zero column after the cos/sin halves.
            pieces.append(torch.zeros_like(pieces[0][:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, time: Tensor) -> Tensor:
        features = self.timestep_embedding(
            time=time, dim=self.frequency_embedding_size
        )
        return self.mlp(features)
class DDiTBlock(nn.Module):
    """Transformer block with adaLN conditioning over packed sequences.

    All sequences of a batch are concatenated along the first axis into a
    flat ``(T, H)`` tensor (``T = sum(lengths)``). Attention is restricted to
    tokens of the same sequence through a FlexAttention ``score_mod``.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        cond_dim: int,
        mlp_ratio: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        assert dim % n_heads == 0, "dim must be devisable by n_heads"
        self.n_heads = n_heads
        self.dim = dim
        self.dropout = dropout
        self.head_dim = self.dim // self.n_heads
        self.norm1 = LayerNorm(dim=dim)
        self.qw = nn.Linear(dim, dim, bias=False)
        self.kw = nn.Linear(dim, dim, bias=False)
        self.vw = nn.Linear(dim, dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        # Not read by forward(), which applies dropout functionally.
        self.dropout1 = nn.Dropout(dropout)
        self.norm2 = LayerNorm(dim=dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )
        # Zero-initialised modulation: shift, scale and gate all start at 0,
        # so both residual branches contribute nothing at initialisation.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(
        self,
        x: Tensor,  # (T, H) flat tokens; T = sum(lengths)
        rotary_caches: tuple,  # (cos, sin)
        c: Tensor,  # (B, cond_dim)
        lengths: Optional[Tensor] = None,  # (B,) tokens per sequence
    ) -> Tensor:
        """Run modulated self-attention and MLP over the packed tokens.

        Args:
            x: Packed token states, shape ``(T, H)``.
            rotary_caches: Rotary ``(cos, sin)`` caches. NOTE(review): not
                consumed yet, the rotary step below is still a TODO.
            c: Per-sequence conditioning, shape ``(B, cond_dim)``.
            lengths: Per-sequence token counts, shape ``(B,)``, summing to
                ``T``. May be omitted only when ``c`` holds one sequence.

        Returns:
            Updated token states, shape ``(T, H)``.

        Raises:
            ValueError: if ``lengths`` is omitted while ``c`` conditions more
                than one sequence.
        """
        T = x.shape[0]
        if lengths is None:
            if c.shape[0] != 1:
                raise ValueError(
                    "lengths is required when c conditions more than one sequence"
                )
            lengths = torch.tensor([T], dtype=torch.long, device=x.device)
        else:
            lengths = lengths.to(x.device)

        # Broadcast the per-sequence modulation to every token with a single
        # gather, then split it into the six adaLN terms, each (T, H).
        seq_ids = build_seq_ids(lengths, device=x.device)  # (T,)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c)[seq_ids].chunk(6, dim=-1)

        x_skip = x
        x = modulate(self.norm1(x), shift=shift_msa, scale=scale_msa)  # (T, H)
        q = self.qw(x).view(T, self.n_heads, self.head_dim)
        k = self.kw(x).view(T, self.n_heads, self.head_dim)
        v = self.vw(x).view(T, self.n_heads, self.head_dim)

        with torch.amp.autocast("cuda", enabled=False):
            # q/k must end up in the same dtype as v after the rotary step.
            dtype = v.dtype
            ##################
            # TODO: apply the rotary embedding from `rotary_caches` to q and k
            # here; until then no positional rotation is applied.
            ##################
            q = q.to(dtype)
            k = k.to(dtype)

        # FlexAttention expects (batch, heads, tokens, head_dim); the packed
        # batch is presented as a single "batch" entry.
        q = q.transpose(0, 1).contiguous().unsqueeze(0)  # (1, Hh, T, Dh)
        k = k.transpose(0, 1).contiguous().unsqueeze(0)  # (1, Hh, T, Dh)
        v = v.transpose(0, 1).contiguous().unsqueeze(0)  # (1, Hh, T, Dh)

        score_mod = make_score_mod_for_intra_sequence_only(lengths)
        attn = flex_attention(q, k, v, score_mod=score_mod)  # (1, Hh, T, Dh)
        attn = attn.squeeze(0)  # (Hh, T, Dh)
        x = attn.transpose(0, 1).contiguous().view(T, self.dim)  # (T, H)

        x = bias_dropout_add_scale(
            self.attn_out(x), gate_msa, x_skip, self.dropout, self.training
        )
        x = bias_dropout_add_scale(
            self.mlp(modulate(self.norm2(x), shift=shift_mlp, scale=scale_mlp)),
            gate_mlp,
            x,
            self.dropout,
            self.training,
        )
        return x
class DDitFinalLayer(nn.Module):
    """Final adaLN-modulated norm followed by a linear projection.

    Both the projection and the modulation layer start at zero, so the layer
    outputs zeros at initialisation.
    """

    def __init__(self, hidden_size: int, out_channels: int, cond_dim: int):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        # Zero every weight and bias of the two linear layers.
        for layer in (self.linear, self.adaLN_modulation):
            layer.weight.data.zero_()
            layer.bias.data.zero_()

    def forward(self, x: Tensor, c: Tensor) -> Tensor:
        # Insert a singleton axis after the batch axis, then split the
        # modulation into its shift and scale halves.
        conditioning = self.adaLN_modulation(c).unsqueeze(1)
        shift, scale = conditioning.chunk(2, dim=2)
        normed = self.norm_final(x)
        return self.linear(modulate(x=normed, shift=shift, scale=scale))
# -------------------------------------------------------------
# NEW: small helper to keep λ ≥ 0
class Positive(nn.Module):
    """Map real inputs to positive values: ``softplus(x + bias)``.

    ``bias`` is a learnable scalar initialised to 0.
    """

    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(()))

    def forward(self, x: Tensor) -> Tensor:
        shifted = x + self.bias
        return F.softplus(shifted)
# NEW: project token hidden states to "slot" hidden states (n+1)
class SlotProjector(nn.Module):
    """
    Turns n token states into n+1 "slot" states, one per gap around tokens.

    Slot i concatenates its left neighbour h_{i-1} and right neighbour h_i;
    learned BOS/EOS vectors stand in for the missing neighbours at the two
    ends. The pair is projected back to the hidden size and passed through
    GELU.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.bos = nn.Parameter(torch.zeros(hidden_size))
        self.eos = nn.Parameter(torch.zeros(hidden_size))
        self.proj = nn.Linear(2 * hidden_size, hidden_size, bias=True)
        self.act = nn.GELU()

    def forward(self, h_tok: Tensor) -> Tensor:
        # h_tok: (B, n, H)
        batch, _, width = h_tok.shape
        start = self.bos.expand(batch, 1, width)  # (B, 1, H)
        end = self.eos.expand(batch, 1, width)  # (B, 1, H)
        left_neighbours = torch.cat([start, h_tok], dim=1)  # (B, n+1, H)
        right_neighbours = torch.cat([h_tok, end], dim=1)  # (B, n+1, H)
        pairs = torch.cat([left_neighbours, right_neighbours], dim=-1)  # (B, n+1, 2H)
        projected = self.proj(pairs)
        return self.act(projected)  # (B, n+1, H)
# NEW: multi-head output layer for edit flows
class EditFlowsHead(nn.Module):
    """
    Output head producing edit rates and token logits.

    Returns, in this order:
      - lam_ins: (B, n+1) insertion rates, one per slot
      - q_ins:   (B, n+1, V) insertion token logits (no softmax applied)
      - lam_del: (B, n) deletion rates
      - lam_sub: (B, n) substitution rates
      - q_sub:   (B, n, V) substitution token logits (no softmax applied)

    Rates are made positive with a shared ``Positive`` module.
    """

    def __init__(self, hidden_size: int, vocab_size: int, cond_dim: int):
        super().__init__()
        self.norm = LayerNorm(hidden_size)
        # Conditioning projection starts at zero (shift = 0, scale = 0).
        self.cond = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        nn.init.zeros_(self.cond.weight)
        nn.init.zeros_(self.cond.bias)

        # Heads evaluated at the n token positions.
        self.lambda_del = nn.Linear(hidden_size, 1, bias=True)
        self.lambda_sub = nn.Linear(hidden_size, 1, bias=True)
        self.q_sub = nn.Linear(hidden_size, vocab_size, bias=True)

        # Heads evaluated at the n+1 slot positions.
        self.slot_proj = SlotProjector(hidden_size)
        self.lambda_ins = nn.Linear(hidden_size, 1, bias=True)
        self.q_ins = nn.Linear(hidden_size, vocab_size, bias=True)

        self.to_positive = Positive()

        # All rate and logit heads start with zero weights and biases.
        zero_init_heads = (
            self.lambda_del,
            self.lambda_sub,
            self.lambda_ins,
            self.q_sub,
            self.q_ins,
        )
        for head in zero_init_heads:
            nn.init.zeros_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(self, h_tok: Tensor, c: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        # Shift/scale modulation of the normalised token states.
        shift, scale = self.cond(c).unsqueeze(1).chunk(2, dim=-1)
        tokens = modulate(self.norm(h_tok), shift=shift, scale=scale)  # (B, n, H)

        # Token-position outputs (n).
        del_rate = self.to_positive(self.lambda_del(tokens)).squeeze(-1)  # (B, n)
        sub_rate = self.to_positive(self.lambda_sub(tokens)).squeeze(-1)  # (B, n)
        sub_logits = self.q_sub(tokens)  # (B, n, V), raw logits

        # Slot-position outputs (n+1).
        slots = self.slot_proj(tokens)  # (B, n+1, H)
        ins_rate = self.to_positive(self.lambda_ins(slots)).squeeze(-1)  # (B, n+1)
        ins_logits = self.q_ins(slots)  # (B, n+1, V), raw logits

        return ins_rate, ins_logits, del_rate, sub_rate, sub_logits
class Transformer(nn.Module):
    """Edit-flows backbone: token embedding, time conditioning, DDiT blocks, edit head.

    Tokens are embedded either with a pretrained ESM model followed by a linear
    projection, or with a plain ``nn.Embedding`` when ``config.esm_model_name``
    is ``None``. The output layer is an ``EditFlowsHead``.

    NOTE(review): ``forward`` is unfinished. It references names that are never
    assigned (``rotary_caches``, ``outs``) and still contains TODO markers, so
    it cannot run as written.
    """

    def __init__(self, vocab_size: int, masked: bool, config: DictConfig):
        """Build the embedder, time embedding, blocks and output head.

        :param vocab_size: number of tokens, excluding the optional mask token.
        :param masked: when true, one extra token id is reserved.
        :param config: model configuration; a plain ``dict`` is converted with
            ``OmegaConf.create``. Read here: ``hidden_size``, ``cond_dim``,
            ``n_heads``, ``n_blocks``, ``dropout`` and the optional ``pad_id``,
            ``esm_model_name``, ``freeze_esm``.
        """
        super().__init__()
        if isinstance(config, dict):
            config = OmegaConf.create(config)
        self.config = config
        self.vocab_size = vocab_size
        self.pad_id = getattr(config, "pad_id", 0)
        add_token = 1 if masked else 0  # keep if you need a mask token elsewhere
        # ESM-2 embedding approach (similar to gpm_model.py)
        esm_model_name = getattr(config, "esm_model_name", "facebook/esm2_t33_650M_UR50D")
        freeze_esm = getattr(config, "freeze_esm", True)
        if esm_model_name is not None:
            self.tok_embedder = EsmModel.from_pretrained(esm_model_name)
            tok_embed_dim = self.tok_embedder.config.hidden_size
            if freeze_esm:
                for param in self.tok_embedder.parameters():
                    param.requires_grad = False
                # NOTE(review): eval() is only set here at construction; a later
                # .train() call on the parent module would presumably switch the
                # embedder back to training mode. Confirm this is acceptable.
                self.tok_embedder.eval()
            # Project from ESM hidden size to model hidden size
            self.tok_embed_to_hidden = nn.Linear(tok_embed_dim, config.hidden_size)
            self.vocab_embed = None  # Not needed when using ESM
        else:
            # Fallback to embedding layer if ESM is not used
            self.tok_embedder = None
            self.tok_embed_to_hidden = None  # Not needed when using standard embedding
            self.vocab_embed = nn.Embedding(self.vocab_size + add_token, config.hidden_size)
        self.time_embedding = TimestepEmbedder(hidden_size=config.cond_dim)
        # One rotary dimension per attention head channel.
        self.rotary_emb = rotary.Rotary(dim=config.hidden_size // config.n_heads)
        self.blocks = nn.ModuleList(
            [
                DDiTBlock(
                    dim=config.hidden_size,
                    n_heads=config.n_heads,
                    cond_dim=config.cond_dim,
                    dropout=config.dropout,
                )
                for _ in range(config.n_blocks)
            ]
        )
        # CHANGED: use EditFlowsHead instead of DDitFinalLayer
        self.output_layer = EditFlowsHead(
            hidden_size=config.hidden_size,
            vocab_size=vocab_size + add_token,
            cond_dim=config.cond_dim,
        )

    def _embed_ragged(self, x_t, mask):
        """Embed a padded ``(B, S)`` token batch.

        :param x_t: padded token ids, shape ``(B, S)``.
        :param mask: passed as the second positional argument of the ESM
            embedder (presumably its attention mask; confirm against EsmModel).
        :return: tuple ``(x, lengths, positions)``. ``lengths`` holds the padded
            length ``S`` for every sequence, not the unpadded lengths.
        """
        # Handle device - use ESM device or vocab_embed device
        if self.tok_embedder is not None:
            device = next(self.tok_embedder.parameters()).device
        else:
            device = self.vocab_embed.weight.device
        if self.tok_embedder is not None:
            x_t_emb = self.tok_embedder(x_t, mask).last_hidden_state
            x = self.tok_embed_to_hidden(x_t_emb)
        else:
            # Fallback to standard embedding
            # NOTE(review): with a (B, S) input this yields (B, S, H); the
            # embeddings are not flattened to (T, H) anywhere in this method.
            x = self.vocab_embed(x_t)  # (T, H)
        # positions from RoPE helper:
        B, S = x_t.shape
        lengths = torch.full((B,), S, device=device, dtype=torch.long)
        # positions_like comes from the project-local rotary module; its exact
        # output is not visible here (presumably per-token positions, TODO confirm).
        positions = self.rotary_emb.positions_like(lengths)  # (T,)
        return x, lengths, positions

    def forward(self, x_t, time: torch.Tensor):
        """
        x_t: List[LongTensor] of variable lengths OR a (B,S) tensor
        time: (B,) float in [0,1]
        returns lists (ragged): lam_ins, q_ins, lam_del, lam_sub, q_sub

        NOTE(review): incomplete. See the inline notes for the undefined names
        and the steps that are still TODO.
        """
        # 0) pad x_t
        x_t = torch.nn.utils.rnn.pad_sequence(x_t, batch_first=True, padding_value=self.pad_id)
        # 1 for real tokens, 0 for padding.
        mask = (x_t != self.pad_id).long()
        # 1) embed ragged
        # NOTE(review): _embed_ragged returns a 3-tuple (x, lengths, positions),
        # so x_emb is a tuple here, yet it is handed to the blocks as `x` below.
        x_emb = self._embed_ragged(x_t, mask)
        # 2) time conditioning
        # NOTE(review): B is assigned but not used in the rest of this method.
        B = x_t.shape[0]
        t_emb = F.silu(self.time_embedding(time))  # (B, cond_dim)
        # 3) rotary caches (you can keep your existing cache logic)
        ###############
        # TODO: build `rotary_caches`; the name is used below but never assigned.
        ###############
        # 4) transformer blocks
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            for blk in self.blocks:
                # NOTE(review): every block receives the same x_emb, so block
                # outputs are not chained. `mask` is not a parameter of
                # DDiTBlock.forward, and the `lengths` computed by
                # _embed_ragged is not forwarded to the blocks.
                x = blk(
                    x=x_emb,
                    mask=mask,
                    rotary_caches=rotary_caches,
                    c=t_emb,
                )
        #############
        # TODO: apply self.output_layer and collect per-sequence results in
        # `outs`; `outs` is never assigned before the zip below.
        #############
        lam_ins, q_ins, lam_del, lam_sub, q_sub = zip(*outs)
        # 6) return ragged lists (keep padding outside the model if needed)
        return list(lam_ins), list(q_ins), list(lam_del), list(lam_sub), list(q_sub)

Xet Storage Details

Size:
17.9 kB
·
Xet hash:
b52f59d10053b69c4bc563dd007471162b020e1236d69a56c52e094c7efa9db8

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.