"""
Enhanced Glycan Classifier with Architecture Improvements

Uses the new architecture components:
- #1 MonosaccharidePooling: Pool tokens to residue level
- #2 ResidueTypeEmbeddings: Add monosaccharide type embeddings
- #4 RelativePositionBias: Tree-aware position encoding
  (imported here, but not instantiated or applied anywhere in this module)
"""

from typing import Optional, Dict

import torch
import torch.nn as nn

try:
    from .multimodal_glycan_bert_v3 import (
        MultimodalGlycanBERT,
        MultimodalGlycanBERTConfig,
        MonosaccharidePooling,
        ResidueTypeEmbeddings,
        RelativePositionBias,
        MONOSACCHARIDE_VOCAB,
    )
except ImportError:
    # The relative import failed; fall back to importing the sibling module
    # by its plain name.
    from multimodal_glycan_bert_v3 import (
        MultimodalGlycanBERT,
        MultimodalGlycanBERTConfig,
        MonosaccharidePooling,
        ResidueTypeEmbeddings,
        RelativePositionBias,
        MONOSACCHARIDE_VOCAB,
    )


class EnhancedGlycanClassifier(nn.Module):
    """
    Classification head built on a MultimodalGlycanBERT sequence encoder.

    Key differences from a basic classifier:
    1. Monosaccharide-level pooling (not first-token or mean) when residue
       IDs are supplied; otherwise masked mean pooling is used.
    2. Optional residue type embeddings added to the token embeddings before
       the transformer layers.

    Relative position bias (#4) is NOT applied by this class; per the
    original notes it requires a modification of the underlying model.
    """

    def __init__(
        self,
        bert: MultimodalGlycanBERT,
        num_classes: int,
        dropout: float = 0.1,
        freeze_layers: int = 4,
        use_mono_pooling: bool = True,
        use_residue_types: bool = True,
    ):
        """
        Args:
            bert: Encoder providing ``seq_embeddings``, ``seq_layers`` and
                ``config.seq_hidden_size``.
            num_classes: Size of the output logit vector.
            dropout: Dropout probability for the pooling module and the
                classification head.
            freeze_layers: The first ``freeze_layers`` entries of
                ``bert.seq_layers`` have ``requires_grad`` disabled.
            use_mono_pooling: Create and use ``MonosaccharidePooling``.
            use_residue_types: Create and use ``ResidueTypeEmbeddings``.
        """
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.use_mono_pooling = use_mono_pooling
        self.use_residue_types = use_residue_types

        hidden_size = bert.config.seq_hidden_size

        # Freeze bottom layers. Only the transformer layers are touched here;
        # bert.seq_embeddings stays trainable.
        for i, layer in enumerate(self.bert.seq_layers):
            if i < freeze_layers:
                for param in layer.parameters():
                    param.requires_grad = False

        # #1: Monosaccharide-level pooling
        if use_mono_pooling:
            self.mono_pooling = MonosaccharidePooling(
                hidden_size=hidden_size,
                num_attention_heads=8,
                dropout=dropout,
            )

        # #2: Residue type embeddings
        if use_residue_types:
            self.residue_embeddings = ResidueTypeEmbeddings(
                hidden_size=hidden_size,
                # Buffer of 10 extra slots for new monosaccharide types
                num_mono_types=len(MONOSACCHARIDE_VOCAB) + 10,
            )

        # Classification head: hidden -> hidden/2 -> num_classes
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        residue_ids: Optional[torch.Tensor] = None,
        mono_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with architecture improvements.

        Args:
            token_ids: Token IDs, (batch, seq_len).
            attention_mask: Attention mask, (batch, seq_len).
            residue_ids: Residue ID for each token (from data),
                (batch, seq_len). When ``None``, residue type embeddings are
                skipped and masked mean pooling is used.
            mono_type_ids: Monosaccharide type ID for each residue
                (from data), (batch, max_residues).

        Returns:
            logits: (batch, num_classes)
        """
        # Get sequence embeddings
        seq_hidden = self.bert.seq_embeddings(token_ids)

        # #2: Add residue type embeddings if available
        # NOTE(review): mono_type_ids is forwarded even when it is None;
        # confirm ResidueTypeEmbeddings accepts that before relying on it.
        if self.use_residue_types and residue_ids is not None:
            seq_hidden = self.residue_embeddings(
                seq_hidden, residue_ids, mono_type_ids
            )

        # Apply transformer layers
        for layer in self.bert.seq_layers:
            seq_hidden = layer(seq_hidden, attention_mask)

        # Pool to glycan representation
        if self.use_mono_pooling and residue_ids is not None:
            # #1: Monosaccharide-level pooling
            pooled = self.mono_pooling(seq_hidden, residue_ids, attention_mask)
        else:
            # Fallback: mean over unmasked positions. The clamp keeps the
            # division finite when a row of the mask sums to zero.
            mask_expanded = attention_mask.unsqueeze(-1).float()
            sum_hidden = (seq_hidden * mask_expanded).sum(dim=1)
            sum_mask = mask_expanded.sum(dim=1).clamp(min=1e-9)
            pooled = sum_hidden / sum_mask

        # Classify
        logits = self.classifier(pooled)
        return logits


def prepare_mono_type_ids(mono_names_batch, max_residues: int = 50, device='cpu'):
    """
    Convert batch of monosaccharide name lists to type ID tensor.

    Names beyond ``max_residues`` are dropped; unused slots stay 0.

    Args:
        mono_names_batch: List of lists of monosaccharide names
        max_residues: Maximum number of residues to pad to
        device: Device for tensor

    Returns:
        mono_type_ids: (batch, max_residues) long tensor on ``device``
    """
    batch_size = len(mono_names_batch)
    # Fill on the CPU and move once at the end: writing single elements into
    # a tensor that already lives on an accelerator costs one device
    # operation per element.
    mono_type_ids = torch.zeros(batch_size, max_residues, dtype=torch.long)

    for b, mono_names in enumerate(mono_names_batch):
        # zip with range() truncates each name list at max_residues.
        for i, name in zip(range(max_residues), mono_names):
            mono_type_ids[b, i] = ResidueTypeEmbeddings.get_mono_type_id(name)

    return mono_type_ids.to(device)


if __name__ == '__main__':
    # Smoke test: build the enhanced classifier and run one dummy batch.
    print("Testing EnhancedGlycanClassifier...")

    config = MultimodalGlycanBERTConfig(use_cnn_frontend=True)
    bert = MultimodalGlycanBERT(config)

    classifier = EnhancedGlycanClassifier(
        bert=bert,
        num_classes=31,  # species task
        use_mono_pooling=True,
        use_residue_types=True,
    )

    # Create dummy input
    batch_size = 2
    seq_len = 64
    token_ids = torch.randint(0, 100, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)
    # Ten consecutive tokens per residue: 0,0,...,0,1,1,...
    residue_ids = torch.div(
        torch.arange(seq_len), 10, rounding_mode='floor'
    ).unsqueeze(0).expand(batch_size, -1)
    mono_type_ids = torch.randint(0, 20, (batch_size, 10))

    logits = classifier(token_ids, attention_mask, residue_ids, mono_type_ids)
    print(f"✅ Output shape: {logits.shape}")
    print(f"✅ Total params: {sum(p.numel() for p in classifier.parameters()):,}")