Polish Affinose release source
Browse files
- SHA256SUMS +4 -4
- src/affinose_dataset.py +6 -7
- src/affinose_model.py +3 -3
- src/bertose_layers.py +4 -6
- src/bertose_model.py +19 -20
SHA256SUMS
CHANGED
|
@@ -3,10 +3,10 @@ f474c23adc30c94a8f6867c8260213eddb07c9732ab285078f2f08a9ad9fd062 ./README.md
|
|
| 3 |
533fe4f9317782e39ad7980caa0f47ad92f1be999dd9489425a9429e2e7c15cb ./checkpoints/affinose_interaction_model.pt
|
| 4 |
043fcb1c7eb97e22fefe8fadadeff97b56ede254b95e318553175837a1e57114 ./config.json
|
| 5 |
1be88f15fd905882c711f7ceb59b619a93f9cee2c6c7c031f8dbabd35b29e9e0 ./requirements.txt
|
| 6 |
-
|
| 7 |
6f83790933a2e4f10160abdca9a6db1a4407ded780abf1d71407963a556dedb6 ./src/affinose_inference.py
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
0bc54399362945601bcfd403441fc80968d173200dd0561f57568b2053a94839 ./src/wurcs_bpe_tokenizer.py
|
| 12 |
6a572afdf53f1494ab96c896876b824ca7ea749777352606aa9f96bf270ceecc ./vocab/bpe_vocabulary.json
|
|
|
|
| 3 |
533fe4f9317782e39ad7980caa0f47ad92f1be999dd9489425a9429e2e7c15cb ./checkpoints/affinose_interaction_model.pt
|
| 4 |
043fcb1c7eb97e22fefe8fadadeff97b56ede254b95e318553175837a1e57114 ./config.json
|
| 5 |
1be88f15fd905882c711f7ceb59b619a93f9cee2c6c7c031f8dbabd35b29e9e0 ./requirements.txt
|
| 6 |
+
175811ef7be90383787858fafbb929d7a87e039bc0185d4d0f6d216dc92a48ed ./src/affinose_dataset.py
|
| 7 |
6f83790933a2e4f10160abdca9a6db1a4407ded780abf1d71407963a556dedb6 ./src/affinose_inference.py
|
| 8 |
+
de14de370a77237ef2ac5c88714c83a54ce3f696efb710f93f8be19106c7fa95 ./src/affinose_model.py
|
| 9 |
+
6362da8e8de0dc4d580c7d94ef6ab1dbc737da13127fc4078681ce6315180086 ./src/bertose_layers.py
|
| 10 |
+
3c5b826fcf5850749f74d980eee48d0595557f3d6e2a58aa873902817eb65c64 ./src/bertose_model.py
|
| 11 |
0bc54399362945601bcfd403441fc80968d173200dd0561f57568b2053a94839 ./src/wurcs_bpe_tokenizer.py
|
| 12 |
6a572afdf53f1494ab96c896876b824ca7ea749777352606aa9f96bf270ceecc ./vocab/bpe_vocabulary.json
|
src/affinose_dataset.py
CHANGED
|
@@ -48,16 +48,15 @@ def load_bpe_tokenizer(vocab_path: str):
|
|
| 48 |
|
| 49 |
utils_dir = None
|
| 50 |
for root in candidate_roots:
|
| 51 |
-
candidate
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
break
|
| 55 |
|
| 56 |
if utils_dir is None:
|
| 57 |
-
utils_dir = Path(
|
| 58 |
-
"/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/"
|
| 59 |
-
"v3.1_cluster_training/bert_training_v4/downstream_tasks/utils"
|
| 60 |
-
)
|
| 61 |
|
| 62 |
utils_dir = str(utils_dir)
|
| 63 |
if utils_dir not in sys.path:
|
|
|
|
| 48 |
|
| 49 |
utils_dir = None
|
| 50 |
for root in candidate_roots:
|
| 51 |
+
for candidate in [root, root / "src", root / "downstream_tasks" / "utils"]:
|
| 52 |
+
if (candidate / "wurcs_bpe_tokenizer.py").exists():
|
| 53 |
+
utils_dir = candidate
|
| 54 |
+
break
|
| 55 |
+
if utils_dir is not None:
|
| 56 |
break
|
| 57 |
|
| 58 |
if utils_dir is None:
|
| 59 |
+
utils_dir = Path(__file__).resolve().parent
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
utils_dir = str(utils_dir)
|
| 62 |
if utils_dir not in sys.path:
|
src/affinose_model.py
CHANGED
|
@@ -40,10 +40,10 @@ def _default_bertose_root() -> Path:
|
|
| 40 |
|
| 41 |
here = Path(__file__).resolve()
|
| 42 |
for parent in here.parents:
|
| 43 |
-
if (parent / "
|
| 44 |
return parent
|
| 45 |
|
| 46 |
-
return
|
| 47 |
|
| 48 |
|
| 49 |
BERTOSE_ROOT = _default_bertose_root()
|
|
@@ -55,7 +55,7 @@ def _ensure_bertose_imports():
|
|
| 55 |
roots = [
|
| 56 |
str(source_dir),
|
| 57 |
str(BERTOSE_ROOT),
|
| 58 |
-
str(BERTOSE_ROOT / "
|
| 59 |
]
|
| 60 |
for root in roots:
|
| 61 |
if root not in sys.path:
|
|
|
|
| 40 |
|
| 41 |
here = Path(__file__).resolve()
|
| 42 |
for parent in here.parents:
|
| 43 |
+
if (parent / "src").exists() or (parent / "bertose_model.py").exists():
|
| 44 |
return parent
|
| 45 |
|
| 46 |
+
return here.parent
|
| 47 |
|
| 48 |
|
| 49 |
BERTOSE_ROOT = _default_bertose_root()
|
|
|
|
| 55 |
roots = [
|
| 56 |
str(source_dir),
|
| 57 |
str(BERTOSE_ROOT),
|
| 58 |
+
str(BERTOSE_ROOT / "src"),
|
| 59 |
]
|
| 60 |
for root in roots:
|
| 61 |
if root not in sys.path:
|
src/bertose_layers.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
|
| 4 |
-
Transformer
|
| 5 |
-
Based on BERT/ESM2 architecture adapted for atomic glycan tokenization.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import torch
|
|
@@ -11,7 +10,7 @@ import math
|
|
| 11 |
|
| 12 |
|
| 13 |
class GlycanBERTConfig:
|
| 14 |
-
"""Configuration for
|
| 15 |
|
| 16 |
def __init__(
|
| 17 |
self,
|
|
@@ -203,7 +202,7 @@ class GlycanBERTLayer(nn.Module):
|
|
| 203 |
|
| 204 |
class GlycanBERT(nn.Module):
|
| 205 |
"""
|
| 206 |
-
|
| 207 |
"""
|
| 208 |
|
| 209 |
def __init__(self, config: GlycanBERTConfig):
|
|
@@ -300,4 +299,3 @@ class GlycanBERT(nn.Module):
|
|
| 300 |
hidden_states = layer(hidden_states, attention_mask)
|
| 301 |
|
| 302 |
return hidden_states
|
| 303 |
-
|
|
|
|
| 1 |
"""
|
| 2 |
+
Bertose transformer layers.
|
| 3 |
|
| 4 |
+
Transformer blocks adapted for WURCS glycan tokenization.
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import torch
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class GlycanBERTConfig:
|
| 13 |
+
"""Configuration for the Bertose transformer stack."""
|
| 14 |
|
| 15 |
def __init__(
|
| 16 |
self,
|
|
|
|
| 202 |
|
| 203 |
class GlycanBERT(nn.Module):
|
| 204 |
"""
|
| 205 |
+
Bertose transformer stack for masked language modeling.
|
| 206 |
"""
|
| 207 |
|
| 208 |
def __init__(self, config: GlycanBERTConfig):
|
|
|
|
| 299 |
hidden_states = layer(hidden_states, attention_mask)
|
| 300 |
|
| 301 |
return hidden_states
|
|
|
src/bertose_model.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
|
| 4 |
-
|
| 5 |
- Sequence (WURCS atomic tokenization)
|
| 6 |
- MS (mass spectrometry peaks, RT, intensity)
|
| 7 |
- 3D structure (VQ-VAE discrete tokens, 4 per residue)
|
|
@@ -40,11 +40,11 @@ class ConvGlycanBERTEmbeddings(nn.Module):
|
|
| 40 |
config.max_position_embeddings, config.hidden_size
|
| 41 |
)
|
| 42 |
|
| 43 |
-
#
|
| 44 |
max_branch_depth = getattr(config, "max_branch_depth", 8)
|
| 45 |
self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size)
|
| 46 |
|
| 47 |
-
#
|
| 48 |
# 0=none, 1=1-3, 2=1-4, 3=1-6, etc.
|
| 49 |
num_linkage_types = getattr(config, "num_linkage_types", 9)
|
| 50 |
self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size)
|
|
@@ -83,13 +83,13 @@ class ConvGlycanBERTEmbeddings(nn.Module):
|
|
| 83 |
position_ids = self.position_ids[:, :seq_len]
|
| 84 |
x = x + self.position_embeddings(position_ids)
|
| 85 |
|
| 86 |
-
#
|
| 87 |
if branch_depths is not None:
|
| 88 |
# Clamp to valid range
|
| 89 |
branch_depths = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1)
|
| 90 |
x = x + self.branch_embeddings(branch_depths)
|
| 91 |
|
| 92 |
-
#
|
| 93 |
if linkage_types is not None:
|
| 94 |
linkage_types = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1)
|
| 95 |
x = x + self.linkage_embeddings(linkage_types)
|
|
@@ -148,7 +148,7 @@ def create_residue_level_mask(
|
|
| 148 |
|
| 149 |
|
| 150 |
class MultimodalGlycanBERTConfig:
|
| 151 |
-
"""Configuration for
|
| 152 |
|
| 153 |
def __init__(
|
| 154 |
self,
|
|
@@ -247,7 +247,7 @@ class MultimodalGlycanBERTConfig:
|
|
| 247 |
self.seq_loss_weight = seq_loss_weight
|
| 248 |
self.ms_loss_weight = ms_loss_weight
|
| 249 |
self.struct_loss_weight = struct_loss_weight
|
| 250 |
-
self.dist_loss_weight = 0.25
|
| 251 |
|
| 252 |
# Token IDs
|
| 253 |
self.pad_token_id = pad_token_id
|
|
@@ -610,7 +610,7 @@ class CrossAttentionLayer(nn.Module):
|
|
| 610 |
|
| 611 |
class MultimodalGlycanBERT(nn.Module):
|
| 612 |
"""
|
| 613 |
-
|
| 614 |
|
| 615 |
Architecture:
|
| 616 |
1. Separate encoders for each modality (sequence, MS, 3D structure)
|
|
@@ -628,7 +628,7 @@ class MultimodalGlycanBERT(nn.Module):
|
|
| 628 |
seq_config.cnn_kernel_size = config.cnn_kernel_size
|
| 629 |
|
| 630 |
if config.use_cnn_frontend:
|
| 631 |
-
print(f"
|
| 632 |
self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config)
|
| 633 |
else:
|
| 634 |
self.seq_embeddings = GlycanBERTEmbeddings(seq_config)
|
|
@@ -675,7 +675,7 @@ class MultimodalGlycanBERT(nn.Module):
|
|
| 675 |
)
|
| 676 |
|
| 677 |
# ===== Distance Prediction Head (Topology) =====
|
| 678 |
-
#
|
| 679 |
# (Batch, 256, 256, 768) -> (Batch, 256, 256, 128) reduces memory by 6x
|
| 680 |
self.dist_proj = nn.Linear(config.seq_hidden_size, 128)
|
| 681 |
self.distance_head = nn.Sequential(
|
|
@@ -706,8 +706,8 @@ class MultimodalGlycanBERT(nn.Module):
|
|
| 706 |
seq_token_ids: torch.Tensor,
|
| 707 |
seq_attention_mask: torch.Tensor,
|
| 708 |
seq_residue_ids: torch.Tensor,
|
| 709 |
-
seq_branch_depths: Optional[torch.Tensor] = None,
|
| 710 |
-
seq_linkage_types: Optional[torch.Tensor] = None,
|
| 711 |
ms_token_ids: torch.Tensor = None,
|
| 712 |
ms_attention_mask: torch.Tensor = None,
|
| 713 |
has_ms: torch.Tensor = None,
|
|
@@ -718,11 +718,11 @@ class MultimodalGlycanBERT(nn.Module):
|
|
| 718 |
seq_labels: Optional[torch.Tensor] = None,
|
| 719 |
ms_labels: Optional[torch.Tensor] = None,
|
| 720 |
struct_labels: Optional[torch.Tensor] = None,
|
| 721 |
-
dist_labels: Optional[torch.Tensor] = None,
|
| 722 |
return_dict: bool = True,
|
| 723 |
) -> Dict[str, torch.Tensor]:
|
| 724 |
"""
|
| 725 |
-
Forward pass for
|
| 726 |
|
| 727 |
Args:
|
| 728 |
seq_token_ids: (batch_size, seq_len) - Sequence token IDs
|
|
@@ -839,7 +839,7 @@ class MultimodalGlycanBERT(nn.Module):
|
|
| 839 |
seq_loss = None
|
| 840 |
ms_loss = None
|
| 841 |
struct_loss = None
|
| 842 |
-
dist_loss = None
|
| 843 |
|
| 844 |
if seq_labels is not None:
|
| 845 |
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
|
@@ -921,11 +921,11 @@ class MultimodalGlycanBERT(nn.Module):
|
|
| 921 |
'seq_loss': seq_loss,
|
| 922 |
'ms_loss': ms_loss,
|
| 923 |
'struct_loss': struct_loss,
|
| 924 |
-
'dist_loss': dist_loss,
|
| 925 |
'seq_logits': seq_logits,
|
| 926 |
'ms_logits': ms_logits,
|
| 927 |
'struct_logits': struct_logits,
|
| 928 |
-
'dist_predictions': dist_predictions,
|
| 929 |
'seq_hidden': seq_hidden,
|
| 930 |
'ms_hidden': ms_hidden,
|
| 931 |
'struct_hidden': struct_hidden,
|
|
@@ -970,7 +970,7 @@ class MultimodalGlycanBERT(nn.Module):
|
|
| 970 |
if __name__ == "__main__":
|
| 971 |
# Test the model
|
| 972 |
print("="*80)
|
| 973 |
-
print("Testing
|
| 974 |
print("="*80)
|
| 975 |
|
| 976 |
# Create config
|
|
@@ -1081,4 +1081,3 @@ if __name__ == "__main__":
|
|
| 1081 |
print(f"\n{'='*80}")
|
| 1082 |
print("Model Test Complete!")
|
| 1083 |
print("="*80)
|
| 1084 |
-
|
|
|
|
| 1 |
"""
|
| 2 |
+
Bertose model
|
| 3 |
|
| 4 |
+
Core glycan representation model with three modalities:
|
| 5 |
- Sequence (WURCS atomic tokenization)
|
| 6 |
- MS (mass spectrometry peaks, RT, intensity)
|
| 7 |
- 3D structure (VQ-VAE discrete tokens, 4 per residue)
|
|
|
|
| 40 |
config.max_position_embeddings, config.hidden_size
|
| 41 |
)
|
| 42 |
|
| 43 |
+
# Branch depth embeddings encode depth in the glycan tree.
|
| 44 |
max_branch_depth = getattr(config, "max_branch_depth", 8)
|
| 45 |
self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size)
|
| 46 |
|
| 47 |
+
# Linkage type embeddings encode glycosidic bond chemistry.
|
| 48 |
# 0=none, 1=1-3, 2=1-4, 3=1-6, etc.
|
| 49 |
num_linkage_types = getattr(config, "num_linkage_types", 9)
|
| 50 |
self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size)
|
|
|
|
| 83 |
position_ids = self.position_ids[:, :seq_len]
|
| 84 |
x = x + self.position_embeddings(position_ids)
|
| 85 |
|
| 86 |
+
# Add branch depth embeddings.
|
| 87 |
if branch_depths is not None:
|
| 88 |
# Clamp to valid range
|
| 89 |
branch_depths = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1)
|
| 90 |
x = x + self.branch_embeddings(branch_depths)
|
| 91 |
|
| 92 |
+
# Add linkage type embeddings.
|
| 93 |
if linkage_types is not None:
|
| 94 |
linkage_types = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1)
|
| 95 |
x = x + self.linkage_embeddings(linkage_types)
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
class MultimodalGlycanBERTConfig:
|
| 151 |
+
"""Configuration for the Bertose model."""
|
| 152 |
|
| 153 |
def __init__(
|
| 154 |
self,
|
|
|
|
| 247 |
self.seq_loss_weight = seq_loss_weight
|
| 248 |
self.ms_loss_weight = ms_loss_weight
|
| 249 |
self.struct_loss_weight = struct_loss_weight
|
| 250 |
+
self.dist_loss_weight = 0.25
|
| 251 |
|
| 252 |
# Token IDs
|
| 253 |
self.pad_token_id = pad_token_id
|
|
|
|
| 610 |
|
| 611 |
class MultimodalGlycanBERT(nn.Module):
|
| 612 |
"""
|
| 613 |
+
Bertose model for glycan representation learning.
|
| 614 |
|
| 615 |
Architecture:
|
| 616 |
1. Separate encoders for each modality (sequence, MS, 3D structure)
|
|
|
|
| 628 |
seq_config.cnn_kernel_size = config.cnn_kernel_size
|
| 629 |
|
| 630 |
if config.use_cnn_frontend:
|
| 631 |
+
print(f"Enabled convolutional front-end (kernel={config.cnn_kernel_size})")
|
| 632 |
self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config)
|
| 633 |
else:
|
| 634 |
self.seq_embeddings = GlycanBERTEmbeddings(seq_config)
|
|
|
|
| 675 |
)
|
| 676 |
|
| 677 |
# ===== Distance Prediction Head (Topology) =====
|
| 678 |
+
# Project down to 128 dimensions first to reduce memory use.
|
| 679 |
# (Batch, 256, 256, 768) -> (Batch, 256, 256, 128) reduces memory by 6x
|
| 680 |
self.dist_proj = nn.Linear(config.seq_hidden_size, 128)
|
| 681 |
self.distance_head = nn.Sequential(
|
|
|
|
| 706 |
seq_token_ids: torch.Tensor,
|
| 707 |
seq_attention_mask: torch.Tensor,
|
| 708 |
seq_residue_ids: torch.Tensor,
|
| 709 |
+
seq_branch_depths: Optional[torch.Tensor] = None,
|
| 710 |
+
seq_linkage_types: Optional[torch.Tensor] = None,
|
| 711 |
ms_token_ids: torch.Tensor = None,
|
| 712 |
ms_attention_mask: torch.Tensor = None,
|
| 713 |
has_ms: torch.Tensor = None,
|
|
|
|
| 718 |
seq_labels: Optional[torch.Tensor] = None,
|
| 719 |
ms_labels: Optional[torch.Tensor] = None,
|
| 720 |
struct_labels: Optional[torch.Tensor] = None,
|
| 721 |
+
dist_labels: Optional[torch.Tensor] = None,
|
| 722 |
return_dict: bool = True,
|
| 723 |
) -> Dict[str, torch.Tensor]:
|
| 724 |
"""
|
| 725 |
+
Forward pass for Bertose.
|
| 726 |
|
| 727 |
Args:
|
| 728 |
seq_token_ids: (batch_size, seq_len) - Sequence token IDs
|
|
|
|
| 839 |
seq_loss = None
|
| 840 |
ms_loss = None
|
| 841 |
struct_loss = None
|
| 842 |
+
dist_loss = None
|
| 843 |
|
| 844 |
if seq_labels is not None:
|
| 845 |
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
|
|
|
| 921 |
'seq_loss': seq_loss,
|
| 922 |
'ms_loss': ms_loss,
|
| 923 |
'struct_loss': struct_loss,
|
| 924 |
+
'dist_loss': dist_loss,
|
| 925 |
'seq_logits': seq_logits,
|
| 926 |
'ms_logits': ms_logits,
|
| 927 |
'struct_logits': struct_logits,
|
| 928 |
+
'dist_predictions': dist_predictions,
|
| 929 |
'seq_hidden': seq_hidden,
|
| 930 |
'ms_hidden': ms_hidden,
|
| 931 |
'struct_hidden': struct_hidden,
|
|
|
|
| 970 |
if __name__ == "__main__":
|
| 971 |
# Test the model
|
| 972 |
print("="*80)
|
| 973 |
+
print("Testing Bertose model")
|
| 974 |
print("="*80)
|
| 975 |
|
| 976 |
# Create config
|
|
|
|
| 1081 |
print(f"\n{'='*80}")
|
| 1082 |
print("Model Test Complete!")
|
| 1083 |
print("="*80)
|
|
|