bertose-affinose-training-code / code /model /enhanced_classifier.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
6.42 kB
"""
Enhanced Glycan Classifier with Architecture Improvements
Uses the new architecture components:
- #1 MonosaccharidePooling: Pool tokens to residue level
- #2 ResidueTypeEmbeddings: Add monosaccharide type embeddings
- #4 RelativePositionBias: Tree-aware position encoding
"""
import torch
import torch.nn as nn
from typing import Optional, Dict
try:
from .multimodal_glycan_bert_v3 import (
MultimodalGlycanBERT,
MultimodalGlycanBERTConfig,
MonosaccharidePooling,
ResidueTypeEmbeddings,
RelativePositionBias,
MONOSACCHARIDE_VOCAB,
)
except ImportError:
from multimodal_glycan_bert_v3 import (
MultimodalGlycanBERT,
MultimodalGlycanBERTConfig,
MonosaccharidePooling,
ResidueTypeEmbeddings,
RelativePositionBias,
MONOSACCHARIDE_VOCAB,
)
class EnhancedGlycanClassifier(nn.Module):
    """
    Classification head on top of a MultimodalGlycanBERT sequence encoder.

    Compared with a basic classifier, this head can optionally use:
    1. Monosaccharide-level pooling (architecture improvement #1) instead of
       first-token or mean pooling.
    2. Residue type embeddings (architecture improvement #2) added to the
       token embeddings before the transformer layers.

    Relative position bias (#4) is not wired in here; it requires a change
    to the encoder itself.

    Args:
        bert: Encoder exposing ``config.seq_hidden_size``, ``seq_embeddings``
            and an iterable ``seq_layers``.
        num_classes: Number of output classes.
        dropout: Dropout probability used by the pooling module and the head.
        freeze_layers: The first ``freeze_layers`` entries of
            ``bert.seq_layers`` get ``requires_grad = False``.
        use_mono_pooling: Build and use ``MonosaccharidePooling``.
        use_residue_types: Build and use ``ResidueTypeEmbeddings``.
        num_pooling_heads: Attention heads passed to ``MonosaccharidePooling``
            (previously hard-coded to 8, which remains the default).
    """

    def __init__(
        self,
        bert: "MultimodalGlycanBERT",
        num_classes: int,
        dropout: float = 0.1,
        freeze_layers: int = 4,
        use_mono_pooling: bool = True,
        use_residue_types: bool = True,
        num_pooling_heads: int = 8,
    ):
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.use_mono_pooling = use_mono_pooling
        self.use_residue_types = use_residue_types
        hidden_size = bert.config.seq_hidden_size

        # Freeze bottom layers (only the transformer layers; the embeddings
        # stay trainable).
        for i, layer in enumerate(self.bert.seq_layers):
            if i < freeze_layers:
                for param in layer.parameters():
                    param.requires_grad = False

        # #1: Monosaccharide-level pooling
        if use_mono_pooling:
            self.mono_pooling = MonosaccharidePooling(
                hidden_size=hidden_size,
                num_attention_heads=num_pooling_heads,
                dropout=dropout,
            )

        # #2: Residue type embeddings
        if use_residue_types:
            self.residue_embeddings = ResidueTypeEmbeddings(
                hidden_size=hidden_size,
                num_mono_types=len(MONOSACCHARIDE_VOCAB) + 10,  # Buffer for new types
            )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def _masked_mean(
        self,
        seq_hidden: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Average ``seq_hidden`` over dim 1, counting only unmasked positions."""
        # Accumulate with a float32 mask, exactly as before.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        sum_hidden = (seq_hidden * mask_expanded).sum(dim=1)
        # The clamp keeps a fully masked row from dividing by zero.
        sum_mask = mask_expanded.sum(dim=1).clamp(min=1e-9)
        pooled = sum_hidden / sum_mask
        # The float32 mask promotes half/bfloat16 activations to float32,
        # which nn.Linear rejects when the head's weights are half/bfloat16.
        # Matching the head's parameter dtype fixes that and is a no-op for
        # float32 and float64 models.
        head_dtype = next(self.classifier.parameters()).dtype
        return pooled.to(head_dtype)

    def forward(
        self,
        token_ids: torch.Tensor,  # (batch, seq_len)
        attention_mask: torch.Tensor,  # (batch, seq_len)
        residue_ids: Optional[torch.Tensor] = None,  # (batch, seq_len) - which residue each token belongs to
        mono_type_ids: Optional[torch.Tensor] = None,  # (batch, max_residues) - monosaccharide type per residue
    ) -> torch.Tensor:
        """
        Forward pass with architecture improvements.

        Args:
            token_ids: Token IDs
            attention_mask: Attention mask
            residue_ids: Residue ID for each token (from data). When omitted,
                residue type embeddings are skipped and pooling falls back
                to a masked mean over tokens.
            mono_type_ids: Monosaccharide type ID for each residue (from data)

        Returns:
            logits: (batch, num_classes)
        """
        # Get sequence embeddings
        seq_hidden = self.bert.seq_embeddings(token_ids)

        # #2: Add residue type embeddings if available
        # NOTE(review): mono_type_ids is forwarded as-is, even when None;
        # confirm ResidueTypeEmbeddings accepts that.
        if self.use_residue_types and residue_ids is not None:
            seq_hidden = self.residue_embeddings(
                seq_hidden, residue_ids, mono_type_ids
            )

        # Apply transformer layers
        for layer in self.bert.seq_layers:
            seq_hidden = layer(seq_hidden, attention_mask)

        # Pool to glycan representation
        if self.use_mono_pooling and residue_ids is not None:
            # #1: Monosaccharide-level pooling
            pooled = self.mono_pooling(seq_hidden, residue_ids, attention_mask)
        else:
            # Fallback: Mean pooling
            pooled = self._masked_mean(seq_hidden, attention_mask)

        # Classify
        logits = self.classifier(pooled)
        return logits
def prepare_mono_type_ids(mono_names_batch, max_residues: int = 50, device='cpu'):
    """
    Convert batch of monosaccharide name lists to type ID tensor.

    Each name is mapped with ``ResidueTypeEmbeddings.get_mono_type_id``.
    Rows shorter than ``max_residues`` are right-padded with 0; names beyond
    ``max_residues`` are dropped.

    Args:
        mono_names_batch: List of lists of monosaccharide names
        max_residues: Maximum number of residues to pad to
        device: Device for tensor

    Returns:
        mono_type_ids: (batch, max_residues) tensor of dtype ``torch.long``
    """
    from itertools import islice

    batch_size = len(mono_names_batch)
    # Fill on the host and move to `device` once at the end. Writing single
    # elements into a tensor that already lives on an accelerator costs one
    # tiny device write per residue.
    mono_type_ids = torch.zeros(batch_size, max_residues, dtype=torch.long)
    for b, mono_names in enumerate(mono_names_batch):
        # islice stops after max_residues names, so the truncated tail is
        # never looked up.
        row_ids = [
            ResidueTypeEmbeddings.get_mono_type_id(name)
            for name in islice(mono_names, max_residues)
        ]
        if row_ids:
            mono_type_ids[b, :len(row_ids)] = torch.tensor(row_ids, dtype=torch.long)
    return mono_type_ids.to(device)
if __name__ == '__main__':
    # Smoke test: build a classifier on a freshly initialised encoder and
    # push one dummy batch through it.
    print("Testing EnhancedGlycanClassifier...")

    bert = MultimodalGlycanBERT(MultimodalGlycanBERTConfig(use_cnn_frontend=True))
    head_kwargs = dict(
        num_classes=31,  # species task
        use_mono_pooling=True,
        use_residue_types=True,
    )
    classifier = EnhancedGlycanClassifier(bert=bert, **head_kwargs)

    # Dummy batch: 2 sequences of 64 tokens, every position unmasked.
    n_samples, n_tokens = 2, 64
    token_ids = torch.randint(0, 100, (n_samples, n_tokens))
    attention_mask = torch.ones(n_samples, n_tokens)

    # Ten consecutive tokens share one residue index (0, 0, ..., 1, 1, ...).
    per_token_residue = torch.div(torch.arange(n_tokens), 10, rounding_mode='floor')
    residue_ids = per_token_residue.unsqueeze(0).expand(n_samples, -1)
    mono_type_ids = torch.randint(0, 20, (n_samples, 10))

    logits = classifier(token_ids, attention_mask, residue_ids, mono_type_ids)
    print(f"✅ Output shape: {logits.shape}")

    total_params = 0
    for weight in classifier.parameters():
        total_params += weight.numel()
    print(f"✅ Total params: {total_params:,}")