File size: 6,419 Bytes
1d6f391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
Enhanced Glycan Classifier with Architecture Improvements

References the new architecture components from multimodal_glycan_bert_v3:
- #1 MonosaccharidePooling: Pool tokens to residue level
- #2 ResidueTypeEmbeddings: Add monosaccharide type embeddings
- #4 RelativePositionBias: Tree-aware position encoding
  (imported below, but not applied by EnhancedGlycanClassifier in this file)
"""

import torch
import torch.nn as nn
from typing import Optional, Dict

try:
    from .multimodal_glycan_bert_v3 import (
        MultimodalGlycanBERT, 
        MultimodalGlycanBERTConfig,
        MonosaccharidePooling,
        ResidueTypeEmbeddings,
        RelativePositionBias,
        MONOSACCHARIDE_VOCAB,
    )
except ImportError:
    from multimodal_glycan_bert_v3 import (
        MultimodalGlycanBERT, 
        MultimodalGlycanBERTConfig,
        MonosaccharidePooling,
        ResidueTypeEmbeddings,
        RelativePositionBias,
        MONOSACCHARIDE_VOCAB,
    )


class EnhancedGlycanClassifier(nn.Module):
    """
    Glycan classification head stacked on a ``MultimodalGlycanBERT`` encoder.

    The encoder's ``seq_embeddings`` and ``seq_layers`` are driven directly
    from ``forward``. On top of that, two optional components are available:

    1. ``MonosaccharidePooling`` (#1) as the pooling step, used instead of a
       masked mean over tokens.
    2. ``ResidueTypeEmbeddings`` (#2), applied to the token embeddings before
       the transformer layers.

    Both only take effect when ``residue_ids`` is passed to ``forward``.
    Relative position bias is not applied by this class.
    """

    def __init__(
        self,
        bert: MultimodalGlycanBERT,
        num_classes: int,
        dropout: float = 0.1,
        freeze_layers: int = 4,
        use_mono_pooling: bool = True,
        use_residue_types: bool = True,
    ):
        """
        Args:
            bert: Encoder; its ``config.seq_hidden_size``, ``seq_embeddings``
                and ``seq_layers`` are used.
            num_classes: Number of logits produced per sample.
            dropout: Dropout probability used in the head and handed to the
                pooling module.
            freeze_layers: How many leading entries of ``bert.seq_layers``
                get ``requires_grad = False`` on all their parameters.
            use_mono_pooling: Create (and later use) ``MonosaccharidePooling``.
            use_residue_types: Create (and later use) ``ResidueTypeEmbeddings``.
        """
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.use_mono_pooling = use_mono_pooling
        self.use_residue_types = use_residue_types

        width = bert.config.seq_hidden_size
        bottleneck = width // 2

        # Freeze the first `freeze_layers` transformer layers; later ones
        # stay trainable.
        for depth, block in enumerate(self.bert.seq_layers):
            if depth >= freeze_layers:
                break
            for weight in block.parameters():
                weight.requires_grad = False

        # Optional submodules are created in a fixed order (pooling, residue
        # embeddings, head) so seeded initialisation stays reproducible.
        if use_mono_pooling:
            # #1: pooling at monosaccharide level
            self.mono_pooling = MonosaccharidePooling(
                hidden_size=width,
                num_attention_heads=8,
                dropout=dropout
            )

        if use_residue_types:
            # #2: residue type embeddings; table is 10 entries larger than
            # the vocabulary as a buffer for new types
            self.residue_embeddings = ResidueTypeEmbeddings(
                hidden_size=width,
                num_mono_types=len(MONOSACCHARIDE_VOCAB) + 10
            )

        # Two-layer MLP head: width -> width // 2 -> num_classes
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(width, bottleneck),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average ``hidden`` over dim 1, weighting each position by ``attention_mask``."""
        weights = attention_mask.unsqueeze(-1).float()
        total = (hidden * weights).sum(dim=1)
        # Clamp so a fully masked row does not divide by zero.
        count = weights.sum(dim=1).clamp(min=1e-9)
        return total / count

    def forward(
        self,
        token_ids: torch.Tensor,           # (batch, seq_len)
        attention_mask: torch.Tensor,       # (batch, seq_len)
        residue_ids: torch.Tensor = None,   # (batch, seq_len) - which residue each token belongs to
        mono_type_ids: torch.Tensor = None, # (batch, max_residues) - monosaccharide type per residue
    ) -> torch.Tensor:
        """
        Encode the tokens, pool them to one vector per glycan and classify.

        Args:
            token_ids: Token IDs
            attention_mask: Attention mask
            residue_ids: Residue ID for each token (from data); when omitted,
                residue embeddings are skipped and masked mean pooling is used
            mono_type_ids: Monosaccharide type ID for each residue (from data);
                forwarded as-is to the residue embedding module

        Returns:
            logits: (batch, num_classes)
        """
        has_residues = residue_ids is not None

        states = self.bert.seq_embeddings(token_ids)

        # #2: enrich embeddings with residue type information
        if has_residues and self.use_residue_types:
            states = self.residue_embeddings(states, residue_ids, mono_type_ids)

        # All encoder layers run, frozen or not
        for block in self.bert.seq_layers:
            states = block(states, attention_mask)

        if has_residues and self.use_mono_pooling:
            # #1: monosaccharide-level pooling
            glycan_vec = self.mono_pooling(states, residue_ids, attention_mask)
        else:
            # Fallback: mean over unmasked tokens
            glycan_vec = self._masked_mean(states, attention_mask)

        return self.classifier(glycan_vec)


def prepare_mono_type_ids(mono_names_batch, max_residues: int = 50, device='cpu') -> torch.Tensor:
    """
    Convert batch of monosaccharide name lists to a padded type ID tensor.

    Every name is mapped through ``ResidueTypeEmbeddings.get_mono_type_id``.
    Rows with fewer than ``max_residues`` names keep 0 in the unused slots;
    names past the first ``max_residues`` of a row are ignored.

    Args:
        mono_names_batch: List of lists of monosaccharide names
        max_residues: Number of residue slots per row (pad / truncate length)
        device: Device for the returned tensor

    Returns:
        mono_type_ids: (batch, max_residues) tensor of dtype ``torch.long``
    """
    # Stage on the CPU and transfer once at the end: writing single elements
    # into a tensor that already lives on an accelerator issues one tiny
    # device operation per residue.
    staging = torch.zeros(
        len(mono_names_batch), max_residues, dtype=torch.long, device='cpu'
    )

    for row, mono_names in enumerate(mono_names_batch):
        # Pairing with range() stops after max_residues names, so nothing is
        # looked up for entries that would be truncated anyway.
        for col, name in zip(range(max_residues), mono_names):
            staging[row, col] = ResidueTypeEmbeddings.get_mono_type_id(name)

    # Resolve `device` through an allocation so every spelling that worked
    # for the direct allocation (str, torch.device, index, None) still lands
    # on the same device.
    target_device = torch.empty(0, device=device).device
    return staging.to(target_device)


def _run_smoke_test():
    """Build a classifier on a fresh encoder and push one random batch through it."""
    print("Testing EnhancedGlycanClassifier...")

    encoder = MultimodalGlycanBERT(MultimodalGlycanBERTConfig(use_cnn_frontend=True))
    model = EnhancedGlycanClassifier(
        bert=encoder,
        num_classes=31,  # species task
        use_mono_pooling=True,
        use_residue_types=True,
    )

    # Dummy batch: 2 samples of 64 tokens, nothing masked out
    n_samples, n_tokens = 2, 64
    token_ids = torch.randint(0, 100, (n_samples, n_tokens))
    attention_mask = torch.ones(n_samples, n_tokens)
    # Each run of 10 consecutive tokens shares one residue id (0, 0, ..., 1, 1, ...)
    per_token_residue = torch.div(torch.arange(n_tokens), 10, rounding_mode='floor')
    residue_ids = per_token_residue.unsqueeze(0).expand(n_samples, -1)
    mono_type_ids = torch.randint(0, 20, (n_samples, 10))

    logits = model(token_ids, attention_mask, residue_ids, mono_type_ids)
    param_count = sum(p.numel() for p in model.parameters())

    print(f"✅ Output shape: {logits.shape}")
    print(f"✅ Total params: {param_count:,}")


if __name__ == '__main__':
    # Manual smoke test when the module is executed as a script
    _run_smoke_test()