| """ |
| Enhanced Glycan Classifier with Architecture Improvements |
| |
| Uses the new architecture components: |
| - #1 MonosaccharidePooling: Pool tokens to residue level |
| - #2 ResidueTypeEmbeddings: Add monosaccharide type embeddings |
| - #4 RelativePositionBias: Tree-aware position encoding |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Optional, Dict |
|
|
| try: |
| from .multimodal_glycan_bert_v3 import ( |
| MultimodalGlycanBERT, |
| MultimodalGlycanBERTConfig, |
| MonosaccharidePooling, |
| ResidueTypeEmbeddings, |
| RelativePositionBias, |
| MONOSACCHARIDE_VOCAB, |
| ) |
| except ImportError: |
| from multimodal_glycan_bert_v3 import ( |
| MultimodalGlycanBERT, |
| MultimodalGlycanBERTConfig, |
| MonosaccharidePooling, |
| ResidueTypeEmbeddings, |
| RelativePositionBias, |
| MONOSACCHARIDE_VOCAB, |
| ) |
|
|
|
|
class EnhancedGlycanClassifier(nn.Module):
    """
    Classification head on top of a MultimodalGlycanBERT sequence encoder.

    Key differences from a basic classifier:
    1. Monosaccharide-level pooling via ``MonosaccharidePooling`` (used when
       ``use_mono_pooling`` is set and ``residue_ids`` are passed to
       ``forward``); otherwise a masked mean over tokens is used.
    2. Optional residue type embeddings via ``ResidueTypeEmbeddings``, applied
       to the embedding output before the encoder layers.

    Relative position bias is not wired into this class.
    """

    def __init__(
        self,
        bert: "MultimodalGlycanBERT",
        num_classes: int,
        dropout: float = 0.1,
        freeze_layers: int = 4,
        use_mono_pooling: bool = True,
        use_residue_types: bool = True,
        num_attention_heads: int = 8,
    ):
        """
        Build the classifier around an existing encoder.

        Args:
            bert: Encoder providing ``config.seq_hidden_size``,
                ``seq_embeddings`` and ``seq_layers``.
            num_classes: Size of the output logit vector.
            dropout: Dropout probability for the pooling module and the
                classification MLP.
            freeze_layers: The first ``freeze_layers`` entries of
                ``bert.seq_layers`` get ``requires_grad = False``. Only those
                layers are touched; other encoder parameters are left as is.
            use_mono_pooling: Create and use the monosaccharide pooling module.
            use_residue_types: Create and use the residue type embeddings.
            num_attention_heads: Head count passed to ``MonosaccharidePooling``
                (defaults to 8, the previously hard-coded value).
        """
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.use_mono_pooling = use_mono_pooling
        self.use_residue_types = use_residue_types

        hidden_size = bert.config.seq_hidden_size

        # Freeze the lowest encoder layers so only the upper ones are tuned.
        for index, layer in enumerate(self.bert.seq_layers):
            if index >= freeze_layers:
                break
            for param in layer.parameters():
                param.requires_grad = False

        if use_mono_pooling:
            self.mono_pooling = MonosaccharidePooling(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                dropout=dropout,
            )

        if use_residue_types:
            # NOTE(review): the +10 over the vocabulary size is presumably
            # headroom for special/unknown type IDs -- confirm.
            self.residue_embeddings = ResidueTypeEmbeddings(
                hidden_size=hidden_size,
                num_mono_types=len(MONOSACCHARIDE_VOCAB) + 10,
            )

        # Two-layer MLP head: hidden -> hidden // 2 -> num_classes.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average ``hidden`` over dim 1, counting only positions where the mask is set."""
        mask = attention_mask.unsqueeze(-1).float()
        summed = (hidden * mask).sum(dim=1)
        # Clamp so a fully masked row yields zeros instead of dividing by zero.
        count = mask.sum(dim=1).clamp(min=1e-9)
        # The float mask promotes half/bfloat16 activations to float32; cast
        # back so the result matches the dtype of the classifier weights.
        return (summed / count).to(hidden.dtype)

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        residue_ids: Optional[torch.Tensor] = None,
        mono_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with architecture improvements.

        Args:
            token_ids: Token IDs
            attention_mask: Attention mask
            residue_ids: Residue ID for each token (from data). When omitted,
                residue type embeddings are skipped and masked mean pooling
                is used instead of monosaccharide pooling.
            mono_type_ids: Monosaccharide type ID for each residue (from data).
                Passed through unchanged to the residue embedding module.

        Returns:
            logits: (batch, num_classes)
        """
        seq_hidden = self.bert.seq_embeddings(token_ids)

        if self.use_residue_types and residue_ids is not None:
            seq_hidden = self.residue_embeddings(
                seq_hidden, residue_ids, mono_type_ids
            )

        for layer in self.bert.seq_layers:
            seq_hidden = layer(seq_hidden, attention_mask)

        if self.use_mono_pooling and residue_ids is not None:
            pooled = self.mono_pooling(seq_hidden, residue_ids, attention_mask)
        else:
            pooled = self._masked_mean(seq_hidden, attention_mask)

        return self.classifier(pooled)
|
|
|
|
def prepare_mono_type_ids(mono_names_batch, max_residues: int = 50, device='cpu') -> torch.Tensor:
    """
    Convert batch of monosaccharide name lists to type ID tensor.

    Each name is mapped with ``ResidueTypeEmbeddings.get_mono_type_id``. Rows
    shorter than ``max_residues`` are padded with 0; names beyond
    ``max_residues`` are dropped.

    Args:
        mono_names_batch: List of lists of monosaccharide names
        max_residues: Maximum number of residues to pad to
        device: Device for tensor

    Returns:
        mono_type_ids: (batch, max_residues) tensor of dtype ``torch.long``
    """
    batch_size = len(mono_names_batch)

    # Fill on the host and move once at the end: writing single elements into
    # a tensor that lives on an accelerator costs one device write per residue.
    mono_type_ids = torch.zeros(
        batch_size, max_residues, dtype=torch.long, device='cpu'
    )

    for b, mono_names in enumerate(mono_names_batch):
        # zip against range() stops after max_residues names (truncation).
        row = [
            int(ResidueTypeEmbeddings.get_mono_type_id(name))
            for _, name in zip(range(max_residues), mono_names)
        ]
        if row:
            mono_type_ids[b, :len(row)] = torch.tensor(
                row, dtype=torch.long, device='cpu'
            )

    return mono_type_ids.to(device=device)
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the classifier on a freshly initialised encoder and
    # push one random batch through it.
    print("Testing EnhancedGlycanClassifier...")

    encoder_config = MultimodalGlycanBERTConfig(use_cnn_frontend=True)
    encoder = MultimodalGlycanBERT(encoder_config)
    model = EnhancedGlycanClassifier(
        bert=encoder,
        num_classes=31,
        use_mono_pooling=True,
        use_residue_types=True,
    )

    # Random inputs: 2 samples of 64 tokens each.
    n_samples, n_tokens = 2, 64
    token_ids = torch.randint(0, 100, (n_samples, n_tokens))
    attention_mask = torch.ones(n_samples, n_tokens)

    # Every run of 10 consecutive tokens shares one residue index.
    per_token_residue = torch.div(torch.arange(n_tokens), 10, rounding_mode='floor')
    residue_ids = per_token_residue.unsqueeze(0).expand(n_samples, -1)
    mono_type_ids = torch.randint(0, 20, (n_samples, 10))

    logits = model(token_ids, attention_mask, residue_ids, mono_type_ids)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"✅ Output shape: {logits.shape}")
    print(f"✅ Total params: {param_count:,}")
|
|