# Source: Hugging Face repository page (uploader: supanthadey1).
# Commit 1d6f391 (verified): "Add BERTose and AFFINose training code release".
# File size: 28.7 kB.
"""
Bertint V8 — Cross-Attention + Live Bertose Finetuning
Architecture:

    GLYCAN:  WURCS → BPE → Bertose (live, freeze layers 0-3) → [B, Lg, 768]
    PROTEIN: precomputed ESM-C → [B, Lp, 960]

        glycan [B, Lg, 768]                    protein [B, Lp, 960]
              ↓ proj(768→512)                        ↓ proj(960→512)
              └──── 2× CrossAttentionBlock(d=512, 8 heads, FFN=1024) ────┘
              ↓                                      ↓
              └──── SHARED mask-aware SWE(d=512, S=512, R=64) ───────────┘
              ↓                                      ↓
          [B, 512]                               [B, 512]
              └──────── element-wise product + sum ──────────────────────┘
                                   ↓
                               [B, 1024]
                                   ↓ MLP → binding score

Key changes from V7:
- Per-residue protein embeddings (not mean-pooled) for cross-attention
- CrossAttentionBlock: glycan tokens attend to protein residues and vice versa
- SWE pooling: variable-length → fixed-length (mask-aware, differentiable)
- Product + sum interaction (from Twin Peaks) instead of concat
Key changes from V3:
- Live Bertose forward pass (not frozen precomputed embeddings)
- ESM-C 300M (960-dim, not ESM-C 600M 1152-dim)
SWE, CrossAttention, and mask handling ported from V3 (Sessions 5-10).
"""
import os
import sys
import math
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================================
# Bertose model imports (same as V7)
# ============================================================================
def _default_bertose_root() -> Path:
    """Locate the Bertose source tree without assuming a Nova-only path.

    Resolution order:
      1. ``BERTOSE_ROOT`` or, failing that, ``BERTOSE_REPO_ROOT`` from the
         environment (user-expanded and resolved).
      2. The nearest ancestor directory of this file that contains both a
         ``bert_training_v4`` and a ``model`` entry.
      3. A hard-coded cluster path as the last resort.
    """
    configured = os.environ.get("BERTOSE_ROOT") or os.environ.get("BERTOSE_REPO_ROOT")
    if configured:
        return Path(configured).expanduser().resolve()

    this_file = Path(__file__).resolve()
    # Lazily walk upwards; the first ancestor holding both entries wins.
    matching_ancestors = (
        ancestor
        for ancestor in this_file.parents
        if (ancestor / "bert_training_v4").exists() and (ancestor / "model").exists()
    )
    cluster_fallback = Path(
        "/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training"
    )
    return next(matching_ancestors, cluster_fallback)


# Resolved once at import time; later changes to the environment are ignored.
BERTOSE_ROOT = _default_bertose_root()
def _ensure_bertose_imports():
    """Make the Bertose packages importable.

    Prepends ``BERTOSE_ROOT`` and then its ``bert_training_v4`` subdirectory
    to ``sys.path``; an entry that is already listed is left where it is.
    """
    for candidate in (BERTOSE_ROOT, BERTOSE_ROOT / "bert_training_v4"):
        entry = str(candidate)
        if entry in sys.path:
            continue
        # Each insert goes to the front, so the later entry ends up first.
        sys.path.insert(0, entry)
def load_bertose_config():
    """Build the Bertose configuration matching the V5b checkpoint.

    Returns:
        A ``MultimodalGlycanBERTConfig`` with a 2200-entry sequence
        vocabulary and the CNN front-end enabled.
    """
    _ensure_bertose_imports()
    # Imported lazily: the module only resolves once sys.path is extended.
    from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERTConfig

    settings = {"seq_vocab_size": 2200, "use_cnn_frontend": True}
    return MultimodalGlycanBERTConfig(**settings)
def load_bertose_encoder(
    checkpoint_path: str, freeze_layers: int = 4
):
    """
    Load the Bertose sequence encoder with pretrained weights.

    The vocabulary size and maximum sequence length are inferred from the
    checkpoint's embedding tables, a full ``MultimodalGlycanBERT`` is built,
    the checkpoint is loaded non-strictly, and only the sequence embedding
    module and transformer layers are kept.

    Args:
        checkpoint_path: Path to the pretrained Bertose checkpoint. Either a
            raw state dict or a dict holding it under ``model_state_dict``.
        freeze_layers: How many leading transformer layers to freeze. Values
            above the layer count freeze every layer; values <= 0 freeze
            none. The embedding module is always frozen.

    Returns:
        Tuple of (bertose_config, seq_embeddings, seq_layers).
    """
    _ensure_bertose_imports()
    from model.multimodal_glycan_bert_v3 import (
        MultimodalGlycanBERT,
        MultimodalGlycanBERTConfig,
    )

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Passing weights_only=True would be safer but may
    # reject checkpoints that store non-tensor objects; confirm before changing.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    state_dict = ckpt.get("model_state_dict", ckpt)

    # Infer vocab size and max position embeddings from the checkpoint itself.
    vocab_size = state_dict["seq_embeddings.token_embeddings.weight"].shape[0]
    max_pos = state_dict["seq_embeddings.position_embeddings.weight"].shape[0]
    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=vocab_size,
        seq_max_length=max_pos,
        use_cnn_frontend=True,
    )

    # Instantiate the full model, then extract the sequence encoder.
    model = MultimodalGlycanBERT(config)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    loaded = len(state_dict) - len(unexpected)
    print(f" Loaded {loaded}/{len(state_dict)} pretrained weight tensors")
    print(f" ({len(missing)} missing in checkpoint, {len(unexpected)} unexpected)")

    seq_embeddings = model.seq_embeddings
    seq_layers = model.seq_layers

    # Freeze the embedding module plus the first ``num_frozen`` layers.
    num_frozen = max(0, min(freeze_layers, len(seq_layers)))
    for param in seq_embeddings.parameters():
        param.requires_grad = False
    for index in range(num_frozen):
        for param in seq_layers[index].parameters():
            param.requires_grad = False

    # Count parameters over the embedding module and every layer.
    encoder_params = list(seq_embeddings.parameters())
    for layer in seq_layers:
        encoder_params.extend(layer.parameters())
    total = sum(p.numel() for p in encoder_params)
    trainable = sum(p.numel() for p in encoder_params if p.requires_grad)

    # Report the range that was actually frozen (the raw argument could be
    # zero, negative, or larger than the number of layers).
    if num_frozen > 0:
        frozen_desc = f"frozen layers 0-{num_frozen - 1}"
    else:
        frozen_desc = "no transformer layers frozen"
    print(
        f" Bertose encoder: {total:,} params total, "
        f"{trainable:,} trainable ({frozen_desc})"
    )
    return config, seq_embeddings, seq_layers
# ============================================================================
# Differentiable Interpolation (from V3, Sessions 5-6)
# ============================================================================
def differentiable_interp1d(
    x: torch.Tensor, y: torch.Tensor, xnew: torch.Tensor
) -> torch.Tensor:
    """
    Fully differentiable 1D linear interpolation.

    Gradients flow through y (values) back to the theta projection and
    earlier layers. The original Interp1d.backward from Twin Peaks
    only returned gradients for xnew (query coords), NOT for y —
    killing 83% of gradient flow. Fixed in Session 5.

    Queries outside the range of ``x`` are clamped to the nearest end value
    (no extrapolation). When ``x`` holds a single point per row the result is
    that row's single ``y`` value repeated for every query.

    Args:
        x: [B, N] sorted input coordinates (index lookup is detached)
        y: [B, N] values at x positions (REQUIRES grad flow!)
        xnew: [B, R] query coordinates

    Returns:
        [B, R] interpolated values
    """
    n_pts = x.shape[1]

    # A single support point defines a constant function. Without this guard
    # the index clamp below would yield 0 and gather(ind - 1) would go out of
    # range.
    if n_pts == 1:
        return y.expand(-1, xnew.shape[1]).clone()

    # Index of the right-hand neighbour of every query; no gradient needed.
    ind = torch.searchsorted(
        x.contiguous().detach(), xnew.contiguous().detach()
    )
    ind = ind.clamp(1, n_pts - 1)

    # Gather neighbour values — preserves gradient flow through y.
    x_lo = torch.gather(x, 1, ind - 1)
    x_hi = torch.gather(x, 1, ind)
    y_lo = torch.gather(y, 1, ind - 1)
    y_hi = torch.gather(y, 1, ind)

    # Linear interpolation weights; the floor avoids division by zero for
    # repeated coordinates, the clamp pins out-of-range queries to the ends.
    denom = (x_hi - x_lo).clamp(min=1e-8)
    alpha = ((xnew - x_lo) / denom).clamp(0, 1)

    # Interpolated value — fully differentiable w.r.t. y_lo and y_hi.
    return y_lo + alpha * (y_hi - y_lo)
# ============================================================================
# SWE Pooling (from V3, mask-aware fixes from Sessions 6-8)
# ============================================================================
class SWE_Pooling(nn.Module):
    """
    Sliced-Wasserstein Embedding pooling.

    Turns variable-length token embeddings [B, L, d_in] into a fixed
    [B, num_slices] vector: tokens are projected onto ``num_slices``
    directions, each slice is sorted, resampled onto ``num_ref_points``
    positions, subtracted from a reference grid and reduced with a linear
    weighting over the reference points.

    From Twin Peaks, with mask-aware sorting (Session 6) and
    degenerate-sample handling (Session 8).
    """

    def __init__(
        self,
        d_in: int,
        num_slices: int,
        num_ref_points: int,
        freeze_swe: bool = False,
    ):
        """
        Args:
            d_in: Dimension of the incoming token embeddings.
            num_slices: Number of projection directions (output width).
            num_ref_points: Number of points in the reference grid.
            freeze_swe: If True, the reference grid and the projection
                directions are excluded from gradient updates.
        """
        super().__init__()
        self.num_slices = num_slices
        self.num_ref_points = num_ref_points

        # Reference distribution: the same [-1, 1] grid for every slice.
        grid = torch.linspace(-1, 1, num_ref_points)
        self.reference = nn.Parameter(
            grid.unsqueeze(1).repeat(1, num_slices),
            requires_grad=not freeze_swe,
        )

        # Projection directions, weight-normalised with the gain pinned to 1
        # so only the direction (weight_v) can be learned.
        directions = nn.Linear(d_in, num_slices, bias=False)
        self.theta = nn.utils.weight_norm(directions, dim=0)
        self.theta.weight_g.data = torch.ones_like(self.theta.weight_g.data)
        self.theta.weight_g.requires_grad = False
        nn.init.normal_(self.theta.weight_v)

        # Linear weighting over the reference points.
        self.weight = nn.Linear(num_ref_points, 1, bias=False)

        if freeze_swe:
            self.theta.weight_v.requires_grad = False
            self.reference.requires_grad = False

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: [B, L, d_in] token embeddings
            mask: [B, L] attention mask (1=valid, 0=padding)

        Returns:
            [B, num_slices] fixed-length representation. With a mask, rows
            that have fewer than two valid tokens come back as zeros.
        """
        batch_size, seq_len, _ = x.shape
        device = x.device

        projected = self.theta(x)  # [B, L, num_slices]

        # Degenerate: a single token is returned as its raw projection
        # (the mask is not consulted on this path).
        if seq_len == 1:
            return projected.squeeze(1)

        if mask is None:
            sorted_vals, _ = torch.sort(projected, dim=1)
            degenerate = None
            n_eff = seq_len
        else:
            # Padding becomes -inf so an ascending sort puts it first and
            # the valid values occupy the tail of each column.
            is_pad = mask.unsqueeze(-1).expand_as(projected) == 0
            sorted_vals, _ = torch.sort(
                projected.masked_fill(is_pad, float("-inf")), dim=1
            )

            valid_counts = mask.sum(dim=1).long()
            degenerate = valid_counts < 2
            safe_counts = valid_counts.clamp(min=2)
            n_eff = safe_counts.max().item()

            # Vectorised tail extraction. Rows shorter than the batch
            # maximum run past the end; clamping repeats their last
            # (largest) value.
            first_valid = (seq_len - safe_counts).unsqueeze(1)
            steps = torch.arange(n_eff, device=device).unsqueeze(0)
            tail_idx = (first_valid + steps).clamp(max=seq_len - 1)
            tail_idx = tail_idx.unsqueeze(-1).expand(
                batch_size, n_eff, self.num_slices
            )
            sorted_vals = torch.gather(sorted_vals, 1, tail_idx)

            # Any -inf still present belongs to a degenerate row.
            sorted_vals = sorted_vals.masked_fill(
                sorted_vals == float("-inf"), 0.0
            )

        # Resample every (sample, slice) column onto the reference grid.
        rows = batch_size * self.num_slices
        src_grid = (
            torch.linspace(0, 1, n_eff, device=device)
            .unsqueeze(0)
            .expand(rows, -1)
        )
        flat_vals = sorted_vals.permute(0, 2, 1).reshape(rows, n_eff)
        dst_grid = (
            torch.linspace(0, 1, self.num_ref_points, device=device)
            .unsqueeze(0)
            .expand(rows, -1)
        )
        resampled = differentiable_interp1d(src_grid, flat_vals, dst_grid)
        resampled = resampled.view(
            batch_size, self.num_slices, self.num_ref_points
        ).permute(0, 2, 1)

        # Difference to the reference grid, then weighted reduction over
        # the reference points → [B, num_slices].
        deviation = (self.reference.expand_as(resampled) - resampled).permute(
            0, 2, 1
        )
        pooled = self.weight(deviation).sum(dim=-1)

        # Zero out rows that had too few valid tokens.
        if degenerate is not None and degenerate.any():
            pooled = pooled.masked_fill(degenerate.unsqueeze(-1), 0.0)
        return pooled
# ============================================================================
# Cross-Attention Block (from V3, NaN guard from Session 9)
# ============================================================================
class CrossAttentionBlock(nn.Module):
    """
    Bidirectional cross-attention between glycan and protein tokens.

    First the glycan tokens query the protein residues (Q=glycan,
    KV=protein); then the protein residues query the already-updated glycan
    tokens (Q=protein, KV=glycan). Each direction is attention + residual +
    LayerNorm followed by a feed-forward sub-layer + residual + LayerNorm.

    Includes a NaN guard for all-masked keys (Session 9 fix) and
    padding-position zeroing (Session 7 fix).
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        ffn_dim: int,
        dropout: float = 0.1,
    ):
        """
        Args:
            d_model: Width of both token streams.
            num_heads: Attention heads per direction.
            ffn_dim: Hidden width of each feed-forward sub-layer.
            dropout: Dropout rate for attention and feed-forward layers.
        """
        super().__init__()

        def build_ffn() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(d_model, ffn_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(ffn_dim, d_model),
                nn.Dropout(dropout),
            )

        # Glycan → Protein direction
        self.glycan_cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.glycan_norm1 = nn.LayerNorm(d_model)
        self.glycan_ffn = build_ffn()
        self.glycan_norm2 = nn.LayerNorm(d_model)

        # Protein → Glycan direction
        self.protein_cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.protein_norm1 = nn.LayerNorm(d_model)
        self.protein_ffn = build_ffn()
        self.protein_norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def _cross_update(
        attn, norm_attn, ffn, norm_ffn,
        query, memory, memory_pad, query_mask, need_weights,
    ):
        """Run one direction: attention, NaN guard, residuals, padding zeroing."""
        attended, weights = attn(
            query=query, key=memory, value=memory,
            key_padding_mask=memory_pad,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        # All keys masked → NaN from the attention; zero it so the residual
        # keeps the query unchanged.
        attended = torch.nan_to_num(attended, nan=0.0)
        hidden = norm_attn(query + attended)
        hidden = norm_ffn(hidden + ffn(hidden))
        if query_mask is not None:
            hidden = hidden * query_mask.unsqueeze(-1)
        return hidden, weights

    def forward(
        self,
        glycan: torch.Tensor,
        protein: torch.Tensor,
        glycan_mask: Optional[torch.Tensor] = None,
        protein_mask: Optional[torch.Tensor] = None,
        return_attention: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns updated (glycan, protein) enriched with cross-modal info.

        With ``return_attention=True`` a third element is returned: a dict
        with the per-head attention weights under ``"glycan_to_protein"``
        and ``"protein_to_glycan"``.

        NaN guard: nn.MultiheadAttention produces NaN when ALL keys
        are masked. We replace NaN→0 so the residual preserves the query.
        """
        # key_padding_mask convention: True marks a padded position.
        glycan_pad = None if glycan_mask is None else ~glycan_mask.bool()
        protein_pad = None if protein_mask is None else ~protein_mask.bool()

        # Glycan attends to protein.
        glycan, g2p_weights = self._cross_update(
            self.glycan_cross_attn, self.glycan_norm1,
            self.glycan_ffn, self.glycan_norm2,
            glycan, protein, protein_pad, glycan_mask, return_attention,
        )

        # Protein attends to the updated glycan.
        protein, p2g_weights = self._cross_update(
            self.protein_cross_attn, self.protein_norm1,
            self.protein_ffn, self.protein_norm2,
            protein, glycan, glycan_pad, protein_mask, return_attention,
        )

        if not return_attention:
            return glycan, protein
        return glycan, protein, {
            "glycan_to_protein": g2p_weights,
            "protein_to_glycan": p2g_weights,
        }
# ============================================================================
# Bertint V8 Model
# ============================================================================
class BertintV8(nn.Module):
    """
    Glycan-protein interaction predictor with cross-attention.

    Glycan: Live Bertose (partially frozen) → per-token [B, Lg, 768]
    Protein: Precomputed ESM-C per-residue [B, Lp, 960]
    Cross-attention: 2 bidirectional layers in shared 512-dim space
    SWE: Variable-length → fixed [B, 512] for each side
    Interaction: product + sum → MLP → scalar

    Only the layers created here (projections, cross-attention, pooling,
    head) are freshly initialised; the supplied Bertose encoder modules keep
    the weights they arrive with.
    """

    _POOLING_MODES = ("swe", "mean", "joint_swe")
    _INTERACTION_MODES = ("product_sum", "concat")
    # Children that carry pretrained weights and must not be re-initialised.
    _PRETRAINED_CHILDREN = ("seq_embeddings", "seq_layers")

    def __init__(
        self,
        seq_embeddings: nn.Module,
        seq_layers: nn.ModuleList,
        glycan_dim: int = 768,
        protein_dim: int = 960,
        shared_dim: int = 512,
        num_cross_layers: int = 2,
        num_heads: int = 8,
        ffn_dim: int = 1024,
        swe_slices: int = 512,
        swe_ref_points: int = 64,
        head_hidden: int = 256,
        dropout: float = 0.1,
        separate_swe: bool = False,
        pooling_mode: str = "swe",
        interaction_mode: str = "product_sum",
        use_cross_attention: bool = True,
    ):
        """
        Args:
            seq_embeddings: Pretrained Bertose embedding layer.
            seq_layers: Pretrained Bertose transformer layers.
            glycan_dim: Bertose output dimension (768).
            protein_dim: ESM-C per-residue dimension (960).
            shared_dim: Shared space for cross-attention (512).
            num_cross_layers: Number of cross-attention blocks.
            num_heads: Attention heads per block.
            ffn_dim: FFN hidden dim in cross-attention.
            swe_slices: Number of SWE projection directions.
            swe_ref_points: Number of SWE reference distribution points.
            head_hidden: MLP head hidden dimension.
            dropout: Dropout rate.
            separate_swe: If True, use independent SWE modules for glycan
                and protein (only consulted when pooling_mode is 'swe').
            pooling_mode: 'swe', 'mean', or 'joint_swe'.
            interaction_mode: 'product_sum' or 'concat' (not used when
                pooling_mode is 'joint_swe').
            use_cross_attention: If False, skip cross-attention.

        Raises:
            ValueError: If ``pooling_mode`` is unknown, or if
                ``interaction_mode`` is unknown while it would be used.
        """
        super().__init__()

        # Fail early with a clear message instead of an UnboundLocalError
        # further down (here or in forward).
        if pooling_mode not in self._POOLING_MODES:
            raise ValueError(
                f"Unknown pooling_mode {pooling_mode!r}; "
                f"expected one of {self._POOLING_MODES}"
            )
        if (
            pooling_mode != "joint_swe"
            and interaction_mode not in self._INTERACTION_MODES
        ):
            raise ValueError(
                f"Unknown interaction_mode {interaction_mode!r}; "
                f"expected one of {self._INTERACTION_MODES}"
            )

        self.separate_swe = separate_swe
        self.pooling_mode = pooling_mode
        self.interaction_mode = interaction_mode
        self.use_cross_attention = use_cross_attention
        print(f" Architecture config:")
        print(f" cross_attention={use_cross_attention}")
        print(f" pooling_mode={pooling_mode}")
        print(f" interaction_mode={interaction_mode}")

        # === Bertose sequence encoder (partially frozen) ===
        self.seq_embeddings = seq_embeddings
        self.seq_layers = seq_layers

        # === Projection to shared space ===
        self.glycan_proj = nn.Sequential(
            nn.Linear(glycan_dim, shared_dim),
            nn.LayerNorm(shared_dim),
        )
        self.protein_proj = nn.Sequential(
            nn.Linear(protein_dim, shared_dim),
            nn.LayerNorm(shared_dim),
        )

        # === Cross-attention stack (optional) ===
        if use_cross_attention:
            self.cross_attention = nn.ModuleList([
                CrossAttentionBlock(
                    d_model=shared_dim,
                    num_heads=num_heads,
                    ffn_dim=ffn_dim,
                    dropout=dropout,
                )
                for _ in range(num_cross_layers)
            ])
        else:
            self.cross_attention = nn.ModuleList()

        # === Pooling ===
        if pooling_mode == "swe":
            if separate_swe:
                self.swe_glycan = SWE_Pooling(
                    d_in=shared_dim, num_slices=swe_slices,
                    num_ref_points=swe_ref_points,
                )
                self.swe_protein = SWE_Pooling(
                    d_in=shared_dim, num_slices=swe_slices,
                    num_ref_points=swe_ref_points,
                )
            else:
                self.swe = SWE_Pooling(
                    d_in=shared_dim, num_slices=swe_slices,
                    num_ref_points=swe_ref_points,
                )
            pool_out_dim = swe_slices
        elif pooling_mode == "mean":
            pool_out_dim = shared_dim
        else:  # "joint_swe" (validated above)
            self.swe_joint = SWE_Pooling(
                d_in=shared_dim, num_slices=swe_slices,
                num_ref_points=swe_ref_points,
            )
            pool_out_dim = swe_slices

        # === Regression head ===
        # Joint pooling feeds one vector; the other modes feed two pooled
        # vectors combined to twice the pooled width.
        if pooling_mode == "joint_swe":
            head_input_dim = pool_out_dim
        else:
            head_input_dim = 2 * pool_out_dim
        self.head = nn.Sequential(
            nn.Linear(head_input_dim, head_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, head_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1),
        )

        # Initialise only the layers created above. Calling self.apply()
        # would also walk into seq_embeddings / seq_layers and overwrite the
        # pretrained Bertose Linear and LayerNorm weights.
        for child_name, child in self.named_children():
            if child_name not in self._PRETRAINED_CHILDREN:
                child.apply(self._init_weights)
        self._count_params()

    def _init_weights(self, module: nn.Module) -> None:
        """Xavier init for Linear, reset LayerNorm, skip weight-normed SWE modules."""
        if isinstance(module, nn.Linear):
            if hasattr(module, "weight_v"):
                return  # Preserve SWE initialization
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _count_params(self) -> None:
        """Log total and trainable parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        print(f"BertintV8: {total:,} total, {trainable:,} trainable")

    def _masked_mean_pool(self, x, mask):
        """Masked mean pooling: average valid tokens only.

        Rows without any valid token divide by 1 and therefore yield zeros.
        """
        mask_expanded = mask.unsqueeze(-1)
        x_masked = x * mask_expanded
        summed = x_masked.sum(dim=1)
        counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
        return summed / counts

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        branch_depths: torch.Tensor,
        linkage_types: torch.Tensor,
        protein_emb: torch.Tensor,
        protein_mask: torch.Tensor,
        log_conc: Optional[torch.Tensor] = None,
        has_conc: Optional[torch.Tensor] = None,
        return_attention: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass with cross-attention.

        Args:
            token_ids: [B, Lg] BPE token IDs.
            attention_mask: [B, Lg] glycan attention mask (1=valid, 0=pad).
            branch_depths: [B, Lg] branch depth per token.
            linkage_types: [B, Lg] linkage type per token.
            protein_emb: [B, Lp, protein_dim] per-residue ESM-C embeddings.
            protein_mask: [B, Lp] protein attention mask (1=valid, 0=pad).
            log_conc: Accepted for interface compatibility; not used here.
            has_conc: Accepted for interface compatibility; not used here.
            return_attention: If True and at least one cross-attention layer
                ran, also return the per-layer attention dicts.

        Returns:
            [B] binding score predictions, or ``(predictions, attention_maps)``
            when attention maps were requested and collected.
        """
        # === 1. Bertose forward: per-token embeddings ===
        x = self.seq_embeddings(token_ids, branch_depths, linkage_types)
        for layer in self.seq_layers:
            x = layer(x, attention_mask)
        # x: [B, Lg, glycan_dim] — all tokens (not just CLS!)

        # Glycan mask: the attention mask from the BPE tokenizer.
        glycan_mask = attention_mask  # [B, Lg], 1=valid, 0=pad

        # === 2. Project to shared space ===
        glycan = self.glycan_proj(x)  # [B, Lg, shared_dim]
        protein = self.protein_proj(protein_emb)  # [B, Lp, shared_dim]

        # Zero padding positions after projection (Session 7 fix:
        # LayerNorm bias produces non-trivial values at padding).
        glycan = glycan * glycan_mask.unsqueeze(-1)
        protein = protein * protein_mask.unsqueeze(-1)

        # === 3. Cross-attention (optional) ===
        all_attention_maps = []
        if self.use_cross_attention:
            for cross_layer in self.cross_attention:
                if return_attention:
                    glycan, protein, attn_dict = cross_layer(
                        glycan, protein, glycan_mask, protein_mask,
                        return_attention=True,
                    )
                    all_attention_maps.append(attn_dict)
                else:
                    glycan, protein = cross_layer(
                        glycan, protein, glycan_mask, protein_mask
                    )

        # === 4./5. Pooling and interaction ===
        if self.pooling_mode == "joint_swe":
            # One SWE over the concatenated token streams; no pairwise
            # interaction step.
            joint_tokens = torch.cat([glycan, protein], dim=1)
            joint_mask = torch.cat([glycan_mask, protein_mask], dim=1)
            features = self.swe_joint(joint_tokens, joint_mask)
        else:
            if self.pooling_mode == "swe":
                if self.separate_swe:
                    glycan_pooled = self.swe_glycan(glycan, glycan_mask)
                    protein_pooled = self.swe_protein(protein, protein_mask)
                else:
                    glycan_pooled = self.swe(glycan, glycan_mask)
                    protein_pooled = self.swe(protein, protein_mask)
            else:  # "mean"
                glycan_pooled = self._masked_mean_pool(glycan, glycan_mask)
                protein_pooled = self._masked_mean_pool(protein, protein_mask)

            if self.interaction_mode == "product_sum":
                features = torch.cat([
                    glycan_pooled * protein_pooled,
                    glycan_pooled + protein_pooled,
                ], dim=-1)
            else:  # "concat"
                features = torch.cat([
                    glycan_pooled, protein_pooled,
                ], dim=-1)

        # === 6. Predict binding score ===
        out = self.head(features).squeeze(-1)
        # Same return contract for every pooling mode (the joint path used
        # to drop the collected attention maps).
        if return_attention and all_attention_maps:
            return out, all_attention_maps
        return out
# ============================================================================
# Loss (same as V7 — simple MSE)
# ============================================================================
class BertintV8Loss(nn.Module):
    """Plain mean-squared-error objective for the regression output."""

    def __init__(self):
        super().__init__()
        # nn.MSELoss with its default settings.
        self.mse = nn.MSELoss()

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Return the MSE between ``pred`` and ``target``."""
        loss = self.mse(pred, target)
        return loss
# ============================================================================
# Sanity check
# ============================================================================
if __name__ == "__main__":
    # Local smoke test: builds the model around a mock encoder, runs one
    # forward/backward pass and fails loudly if anything looks wrong.
    print("=" * 60)
    print("BertintV8 Architecture Sanity Check")
    print("=" * 60)

    glycan_dim = 768
    protein_dim = 960

    # Mock Bertose encoder (for testing without cluster)
    class MockEmbeddings(nn.Module):
        """Mock Bertose embeddings for local testing."""

        def __init__(self, dim: int = 768):
            super().__init__()
            self.dim = dim
            # Never used in forward; it only gives the mock some parameters,
            # so these parameters receive no gradient in the check below.
            self.proj = nn.Linear(64, dim)

        def forward(self, token_ids, branch_depths, linkage_types):
            """Return random embeddings of width ``dim``."""
            batch_size, seq_len = token_ids.shape
            return torch.randn(batch_size, seq_len, self.dim)

    class MockLayer(nn.Module):
        """Mock transformer layer."""

        def forward(self, x, mask):
            """Identity."""
            return x

    seq_emb = MockEmbeddings(glycan_dim)
    seq_layers = nn.ModuleList([MockLayer() for _ in range(12)])
    model = BertintV8(
        seq_embeddings=seq_emb,
        seq_layers=seq_layers,
        glycan_dim=glycan_dim,
        protein_dim=protein_dim,
    )

    # Simulate batch
    batch_size = 4
    lg = 37  # Glycan: 37 BPE tokens
    lp = 150  # Protein: 150 residues
    token_ids = torch.randint(0, 100, (batch_size, lg))
    attention_mask = torch.ones(batch_size, lg).float()
    branch_depths = torch.zeros(batch_size, lg, dtype=torch.long)
    linkage_types = torch.zeros(batch_size, lg, dtype=torch.long)
    protein_emb = torch.randn(batch_size, lp, protein_dim)
    protein_mask = torch.ones(batch_size, lp).float()

    out = model(
        token_ids=token_ids,
        attention_mask=attention_mask,
        branch_depths=branch_depths,
        linkage_types=linkage_types,
        protein_emb=protein_emb,
        protein_mask=protein_mask,
    )
    print("\nInput shapes:")
    print(f" Glycan tokens: {token_ids.shape}")
    print(f" Protein emb: {protein_emb.shape}")
    print(f"\nOutput shape: {out.shape} — values: {out.detach()}")
    if tuple(out.shape) != (batch_size,):
        raise RuntimeError(
            f"Expected output shape ({batch_size},), got {tuple(out.shape)}"
        )

    # Test loss
    loss_fn = BertintV8Loss()
    target = torch.rand(batch_size)
    loss = loss_fn(out, target)
    print(f"\nLoss: {loss.item():.6f}")
    if not torch.isfinite(loss):
        raise RuntimeError("Loss is not finite")

    # Test backward
    loss.backward()
    grad_count = sum(
        1 for p in model.parameters() if p.grad is not None and p.requires_grad
    )
    total_trainable = sum(
        1 for p in model.parameters() if p.requires_grad
    )
    print(f"Gradients: {grad_count}/{total_trainable} trainable params")
    if grad_count == 0:
        raise RuntimeError("No trainable parameter received a gradient")

    # Only reached when every check above succeeded.
    print("\n✅ V8 sanity check passed!")
"""Bertint V8 model — cross-attention + live Bertose finetuning."""