"""
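Enhanced glycan classifier built on the v3 architecture components.
Components imported from multimodal_glycan_bert_v3:
- #1 MonosaccharidePooling: pool token states to the residue (monosaccharide) level
- #2 ResidueTypeEmbeddings: add monosaccharide type embeddings to the token embeddings
- #4 RelativePositionBias: tree-aware position encoding (imported, but not yet
  applied by EnhancedGlycanClassifier; using it requires modifying the model)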
"""
import torch
import torch.nn as nn
from typing import Optional, Dict
try:
from .multimodal_glycan_bert_v3 import (
MultimodalGlycanBERT,
MultimodalGlycanBERTConfig,
MonosaccharidePooling,
ResidueTypeEmbeddings,
RelativePositionBias,
MONOSACCHARIDE_VOCAB,
)
except ImportError:
from multimodal_glycan_bert_v3 import (
MultimodalGlycanBERT,
MultimodalGlycanBERTConfig,
MonosaccharidePooling,
ResidueTypeEmbeddings,
RelativePositionBias,
MONOSACCHARIDE_VOCAB,
)
class EnhancedGlycanClassifier(nn.Module):
    """
    Glycan classification head on top of a MultimodalGlycanBERT encoder.

    Architecture improvements used here:
    1. Monosaccharide-level pooling (#1) instead of first-token or mean pooling.
    2. Optional residue type embeddings (#2) added to the token embeddings.

    Relative position bias (#4) is NOT applied by this class; it would
    require modifying the encoder's attention layers.

    Only the sequence branch of ``bert`` is used (``seq_embeddings`` followed
    by ``seq_layers``); ``bert.forward`` is never called.
    """

    def __init__(
        self,
        bert: MultimodalGlycanBERT,
        num_classes: int,
        dropout: float = 0.1,
        freeze_layers: int = 4,
        use_mono_pooling: bool = True,
        use_residue_types: bool = True,
        mono_pooling_heads: int = 8,
    ):
        """
        Args:
            bert: Encoder whose ``config.seq_hidden_size`` sets the width of
                every module in this head.
            num_classes: Number of output logits.
            dropout: Dropout rate used by the pooling module and the MLP head.
            freeze_layers: The first ``freeze_layers`` entries of
                ``bert.seq_layers`` get ``requires_grad=False``. The embedding
                layer is left trainable.
            use_mono_pooling: Build a MonosaccharidePooling module (#1).
            use_residue_types: Build a ResidueTypeEmbeddings module (#2).
            mono_pooling_heads: Number of attention heads passed to
                MonosaccharidePooling (formerly hard-coded to 8).
        """
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.use_mono_pooling = use_mono_pooling
        self.use_residue_types = use_residue_types
        hidden_size = bert.config.seq_hidden_size

        # Freeze the bottom ``freeze_layers`` transformer layers.
        for i, layer in enumerate(self.bert.seq_layers):
            if i >= freeze_layers:
                break
            for param in layer.parameters():
                param.requires_grad = False

        # #1: Monosaccharide-level pooling
        if use_mono_pooling:
            self.mono_pooling = MonosaccharidePooling(
                hidden_size=hidden_size,
                num_attention_heads=mono_pooling_heads,
                dropout=dropout,
            )

        # #2: Residue type embeddings
        if use_residue_types:
            self.residue_embeddings = ResidueTypeEmbeddings(
                hidden_size=hidden_size,
                # Spare slots so type IDs added after training still index the table.
                num_mono_types=len(MONOSACCHARIDE_VOCAB) + 10,
            )

        # Two-layer MLP classification head.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        residue_ids: Optional[torch.Tensor] = None,
        mono_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute class logits for a batch of glycans.

        Args:
            token_ids: Token IDs, expected shape (batch, seq_len).
            attention_mask: Token mask, expected shape (batch, seq_len).
            residue_ids: Residue index each token belongs to, expected shape
                (batch, seq_len). When None, residue type embeddings are
                skipped and pooling falls back to a masked mean.
            mono_type_ids: Monosaccharide type ID per residue, expected shape
                (batch, max_residues). Forwarded as-is (possibly None) to the
                residue type embeddings.

        Returns:
            logits: (batch, num_classes)
        """
        seq_hidden = self.bert.seq_embeddings(token_ids)

        # #2: Residue type embeddings need to know which residue each token is in.
        if self.use_residue_types and residue_ids is not None:
            seq_hidden = self.residue_embeddings(seq_hidden, residue_ids, mono_type_ids)

        for layer in self.bert.seq_layers:
            seq_hidden = layer(seq_hidden, attention_mask)

        if self.use_mono_pooling and residue_ids is not None:
            # #1: Monosaccharide-level pooling
            pooled = self.mono_pooling(seq_hidden, residue_ids, attention_mask)
        else:
            pooled = self._masked_mean(seq_hidden, attention_mask)

        return self.classifier(pooled)

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average ``hidden`` over dim 1, counting only positions where the mask is non-zero.

        The reduction runs in at least float32, so it is safe for half-precision
        inputs (e.g. the 1e-9 clamp would underflow to 0 in fp16). The result
        is cast back to ``hidden.dtype`` so it matches the classifier weights.
        """
        compute_dtype = torch.promote_types(hidden.dtype, torch.float32)
        mask = attention_mask.unsqueeze(-1).to(compute_dtype)
        summed = (hidden.to(compute_dtype) * mask).sum(dim=1)
        # Clamp avoids 0/0 for rows whose mask is entirely zero.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).to(hidden.dtype)
def prepare_mono_type_ids(mono_names_batch, max_residues: int = 50, device='cpu'):
    """
    Convert a batch of monosaccharide name lists into a padded type-ID tensor.

    Each name is mapped through ``ResidueTypeEmbeddings.get_mono_type_id``.
    Samples with more than ``max_residues`` names are truncated; shorter
    ones are right-padded with 0.

    Args:
        mono_names_batch: Iterable (one entry per sample) of iterables of
            monosaccharide names.
        max_residues: Number of residue slots per sample.
        device: Device on which the returned tensor is allocated.

    Returns:
        mono_type_ids: (batch, max_residues) LongTensor.
    """
    from itertools import islice

    # Collect IDs in host-side Python lists and build the tensor in one call.
    # Writing element-by-element into a tensor already on ``device`` costs one
    # tiny host-to-device write per name when ``device`` is a GPU.
    # NOTE(review): assumes get_mono_type_id returns a Python int.
    rows = []
    for mono_names in mono_names_batch:
        ids = [
            ResidueTypeEmbeddings.get_mono_type_id(name)
            for name in islice(mono_names, max_residues)
        ]
        ids.extend([0] * (max_residues - len(ids)))  # 0-padding, as before
        rows.append(ids)

    # ``view`` keeps the (0, max_residues) shape for an empty batch.
    return torch.tensor(rows, dtype=torch.long, device=device).view(len(rows), max_residues)
if __name__ == '__main__':
    # Smoke test: build a classifier on a fresh encoder and check the logits shape.
    print("Testing EnhancedGlycanClassifier...")
    config = MultimodalGlycanBERTConfig(use_cnn_frontend=True)
    bert = MultimodalGlycanBERT(config)
    num_classes = 31  # species task
    classifier = EnhancedGlycanClassifier(
        bert=bert,
        num_classes=num_classes,
        use_mono_pooling=True,
        use_residue_types=True,
    )
    classifier.eval()  # disable dropout for a deterministic forward pass

    # Dummy input
    batch_size = 2
    seq_len = 64
    tokens_per_residue = 10
    max_residues = 10
    token_ids = torch.randint(0, 100, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)
    # Each run of ``tokens_per_residue`` consecutive tokens shares one residue ID.
    residue_ids = (
        torch.div(torch.arange(seq_len), tokens_per_residue, rounding_mode='floor')
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    # Sample type IDs from the known vocabulary so they always index the
    # embedding table (sized len(MONOSACCHARIDE_VOCAB) + 10).
    mono_type_ids = torch.randint(0, len(MONOSACCHARIDE_VOCAB), (batch_size, max_residues))

    with torch.no_grad():
        logits = classifier(token_ids, attention_mask, residue_ids, mono_type_ids)

    expected_shape = (batch_size, num_classes)
    if tuple(logits.shape) != expected_shape:
        raise RuntimeError(f"Unexpected logits shape {tuple(logits.shape)}, expected {expected_shape}")

    print(f"✅ Output shape: {logits.shape}")
    print(f"✅ Total params: {sum(p.numel() for p in classifier.parameters()):,}")
    print(f"✅ Trainable params: {sum(p.numel() for p in classifier.parameters() if p.requires_grad):,}")
|