supanthadey1 committed on
Commit
399a070
·
verified ·
1 Parent(s): 8dc2602

Polish Affinose release source

Browse files
SHA256SUMS CHANGED
@@ -3,10 +3,10 @@ f474c23adc30c94a8f6867c8260213eddb07c9732ab285078f2f08a9ad9fd062 ./README.md
3
  533fe4f9317782e39ad7980caa0f47ad92f1be999dd9489425a9429e2e7c15cb ./checkpoints/affinose_interaction_model.pt
4
  043fcb1c7eb97e22fefe8fadadeff97b56ede254b95e318553175837a1e57114 ./config.json
5
  1be88f15fd905882c711f7ceb59b619a93f9cee2c6c7c031f8dbabd35b29e9e0 ./requirements.txt
6
- 3efc954aff46a0017146e60a803f26b7fa15ded4bcba66653ae443456ae5ab2d ./src/affinose_dataset.py
7
  6f83790933a2e4f10160abdca9a6db1a4407ded780abf1d71407963a556dedb6 ./src/affinose_inference.py
8
- eda11ff25cf5bf34ebcb1dd0a81c9184314c66ef9664f46e8824f6a1e7769217 ./src/affinose_model.py
9
- b69f14c9976951325e3a0a4e8107a16126e67d410e966650f513f1f538a732bb ./src/bertose_layers.py
10
- f247a6c09132a61cb649acfe022b269b5b94c37a5069fcb62045f3340b96b191 ./src/bertose_model.py
11
  0bc54399362945601bcfd403441fc80968d173200dd0561f57568b2053a94839 ./src/wurcs_bpe_tokenizer.py
12
  6a572afdf53f1494ab96c896876b824ca7ea749777352606aa9f96bf270ceecc ./vocab/bpe_vocabulary.json
 
3
  533fe4f9317782e39ad7980caa0f47ad92f1be999dd9489425a9429e2e7c15cb ./checkpoints/affinose_interaction_model.pt
4
  043fcb1c7eb97e22fefe8fadadeff97b56ede254b95e318553175837a1e57114 ./config.json
5
  1be88f15fd905882c711f7ceb59b619a93f9cee2c6c7c031f8dbabd35b29e9e0 ./requirements.txt
6
+ 175811ef7be90383787858fafbb929d7a87e039bc0185d4d0f6d216dc92a48ed ./src/affinose_dataset.py
7
  6f83790933a2e4f10160abdca9a6db1a4407ded780abf1d71407963a556dedb6 ./src/affinose_inference.py
8
+ de14de370a77237ef2ac5c88714c83a54ce3f696efb710f93f8be19106c7fa95 ./src/affinose_model.py
9
+ 6362da8e8de0dc4d580c7d94ef6ab1dbc737da13127fc4078681ce6315180086 ./src/bertose_layers.py
10
+ 3c5b826fcf5850749f74d980eee48d0595557f3d6e2a58aa873902817eb65c64 ./src/bertose_model.py
11
  0bc54399362945601bcfd403441fc80968d173200dd0561f57568b2053a94839 ./src/wurcs_bpe_tokenizer.py
12
  6a572afdf53f1494ab96c896876b824ca7ea749777352606aa9f96bf270ceecc ./vocab/bpe_vocabulary.json
src/affinose_dataset.py CHANGED
@@ -48,16 +48,15 @@ def load_bpe_tokenizer(vocab_path: str):
48
 
49
  utils_dir = None
50
  for root in candidate_roots:
51
- candidate = root / "bert_training_v4" / "downstream_tasks" / "utils"
52
- if candidate.exists():
53
- utils_dir = candidate
 
 
54
  break
55
 
56
  if utils_dir is None:
57
- utils_dir = Path(
58
- "/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/"
59
- "v3.1_cluster_training/bert_training_v4/downstream_tasks/utils"
60
- )
61
 
62
  utils_dir = str(utils_dir)
63
  if utils_dir not in sys.path:
 
48
 
49
  utils_dir = None
50
  for root in candidate_roots:
51
+ for candidate in [root, root / "src", root / "downstream_tasks" / "utils"]:
52
+ if (candidate / "wurcs_bpe_tokenizer.py").exists():
53
+ utils_dir = candidate
54
+ break
55
+ if utils_dir is not None:
56
  break
57
 
58
  if utils_dir is None:
59
+ utils_dir = Path(__file__).resolve().parent
 
 
 
60
 
61
  utils_dir = str(utils_dir)
62
  if utils_dir not in sys.path:
src/affinose_model.py CHANGED
@@ -40,10 +40,10 @@ def _default_bertose_root() -> Path:
40
 
41
  here = Path(__file__).resolve()
42
  for parent in here.parents:
43
- if (parent / "bert_training_v4").exists() and (parent / "model").exists():
44
  return parent
45
 
46
- return Path("/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training")
47
 
48
 
49
  BERTOSE_ROOT = _default_bertose_root()
@@ -55,7 +55,7 @@ def _ensure_bertose_imports():
55
  roots = [
56
  str(source_dir),
57
  str(BERTOSE_ROOT),
58
- str(BERTOSE_ROOT / "bert_training_v4"),
59
  ]
60
  for root in roots:
61
  if root not in sys.path:
 
40
 
41
  here = Path(__file__).resolve()
42
  for parent in here.parents:
43
+ if (parent / "src").exists() or (parent / "bertose_model.py").exists():
44
  return parent
45
 
46
+ return here.parent
47
 
48
 
49
  BERTOSE_ROOT = _default_bertose_root()
 
55
  roots = [
56
  str(source_dir),
57
  str(BERTOSE_ROOT),
58
+ str(BERTOSE_ROOT / "src"),
59
  ]
60
  for root in roots:
61
  if root not in sys.path:
src/bertose_layers.py CHANGED
@@ -1,8 +1,7 @@
1
  """
2
- Glycan BERT Model
3
 
4
- Transformer-based masked language model for glycan structures.
5
- Based on BERT/ESM2 architecture adapted for atomic glycan tokenization.
6
  """
7
 
8
  import torch
@@ -11,7 +10,7 @@ import math
11
 
12
 
13
  class GlycanBERTConfig:
14
- """Configuration for GlycanBERT."""
15
 
16
  def __init__(
17
  self,
@@ -203,7 +202,7 @@ class GlycanBERTLayer(nn.Module):
203
 
204
  class GlycanBERT(nn.Module):
205
  """
206
- Glycan BERT model for masked language modeling.
207
  """
208
 
209
  def __init__(self, config: GlycanBERTConfig):
@@ -300,4 +299,3 @@ class GlycanBERT(nn.Module):
300
  hidden_states = layer(hidden_states, attention_mask)
301
 
302
  return hidden_states
303
-
 
1
  """
2
+ Bertose transformer layers.
3
 
4
+ Transformer blocks adapted for WURCS glycan tokenization.
 
5
  """
6
 
7
  import torch
 
10
 
11
 
12
  class GlycanBERTConfig:
13
+ """Configuration for the Bertose transformer stack."""
14
 
15
  def __init__(
16
  self,
 
202
 
203
  class GlycanBERT(nn.Module):
204
  """
205
+ Bertose transformer stack for masked language modeling.
206
  """
207
 
208
  def __init__(self, config: GlycanBERTConfig):
 
299
  hidden_states = layer(hidden_states, attention_mask)
300
 
301
  return hidden_states
 
src/bertose_model.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
- Multimodal Glycan BERT Model v3
3
 
4
- Extends GlycanBERT to handle three modalities:
5
  - Sequence (WURCS atomic tokenization)
6
  - MS (mass spectrometry peaks, RT, intensity)
7
  - 3D structure (VQ-VAE discrete tokens, 4 per residue)
@@ -40,11 +40,11 @@ class ConvGlycanBERTEmbeddings(nn.Module):
40
  config.max_position_embeddings, config.hidden_size
41
  )
42
 
43
- # NEW: Branch depth embeddings - encodes depth in glycan tree (0=root, 1=child, etc.)
44
  max_branch_depth = getattr(config, "max_branch_depth", 8)
45
  self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size)
46
 
47
- # NEW: Linkage type embeddings - encodes chemistry of glycosidic bond
48
  # 0=none, 1=1-3, 2=1-4, 3=1-6, etc.
49
  num_linkage_types = getattr(config, "num_linkage_types", 9)
50
  self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size)
@@ -83,13 +83,13 @@ class ConvGlycanBERTEmbeddings(nn.Module):
83
  position_ids = self.position_ids[:, :seq_len]
84
  x = x + self.position_embeddings(position_ids)
85
 
86
- # NEW: Add branch depth embeddings (encodes tree structure)
87
  if branch_depths is not None:
88
  # Clamp to valid range
89
  branch_depths = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1)
90
  x = x + self.branch_embeddings(branch_depths)
91
 
92
- # NEW: Add linkage type embeddings (encodes bond chemistry)
93
  if linkage_types is not None:
94
  linkage_types = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1)
95
  x = x + self.linkage_embeddings(linkage_types)
@@ -148,7 +148,7 @@ def create_residue_level_mask(
148
 
149
 
150
  class MultimodalGlycanBERTConfig:
151
- """Configuration for Multimodal GlycanBERT v3."""
152
 
153
  def __init__(
154
  self,
@@ -247,7 +247,7 @@ class MultimodalGlycanBERTConfig:
247
  self.seq_loss_weight = seq_loss_weight
248
  self.ms_loss_weight = ms_loss_weight
249
  self.struct_loss_weight = struct_loss_weight
250
- self.dist_loss_weight = 0.25 # NEW: Topology loss weight (default, can override from config)
251
 
252
  # Token IDs
253
  self.pad_token_id = pad_token_id
@@ -610,7 +610,7 @@ class CrossAttentionLayer(nn.Module):
610
 
611
  class MultimodalGlycanBERT(nn.Module):
612
  """
613
- Multimodal BERT for glycan representation learning (v3).
614
 
615
  Architecture:
616
  1. Separate encoders for each modality (sequence, MS, 3D structure)
@@ -628,7 +628,7 @@ class MultimodalGlycanBERT(nn.Module):
628
  seq_config.cnn_kernel_size = config.cnn_kernel_size
629
 
630
  if config.use_cnn_frontend:
631
- print(f"Enabled Convolutional Front-End (Kernel={config.cnn_kernel_size})")
632
  self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config)
633
  else:
634
  self.seq_embeddings = GlycanBERTEmbeddings(seq_config)
@@ -675,7 +675,7 @@ class MultimodalGlycanBERT(nn.Module):
675
  )
676
 
677
  # ===== Distance Prediction Head (Topology) =====
678
- # OPTIMIZED: Project down to 128 dim first to save GPU memory
679
  # (Batch, 256, 256, 768) -> (Batch, 256, 256, 128) reduces memory by 6x
680
  self.dist_proj = nn.Linear(config.seq_hidden_size, 128)
681
  self.distance_head = nn.Sequential(
@@ -706,8 +706,8 @@ class MultimodalGlycanBERT(nn.Module):
706
  seq_token_ids: torch.Tensor,
707
  seq_attention_mask: torch.Tensor,
708
  seq_residue_ids: torch.Tensor,
709
- seq_branch_depths: Optional[torch.Tensor] = None, # NEW: Branch depths
710
- seq_linkage_types: Optional[torch.Tensor] = None, # NEW: Linkage types
711
  ms_token_ids: torch.Tensor = None,
712
  ms_attention_mask: torch.Tensor = None,
713
  has_ms: torch.Tensor = None,
@@ -718,11 +718,11 @@ class MultimodalGlycanBERT(nn.Module):
718
  seq_labels: Optional[torch.Tensor] = None,
719
  ms_labels: Optional[torch.Tensor] = None,
720
  struct_labels: Optional[torch.Tensor] = None,
721
- dist_labels: Optional[torch.Tensor] = None, # NEW: Topology distance labels
722
  return_dict: bool = True,
723
  ) -> Dict[str, torch.Tensor]:
724
  """
725
- Forward pass for multimodal BERT v3.
726
 
727
  Args:
728
  seq_token_ids: (batch_size, seq_len) - Sequence token IDs
@@ -839,7 +839,7 @@ class MultimodalGlycanBERT(nn.Module):
839
  seq_loss = None
840
  ms_loss = None
841
  struct_loss = None
842
- dist_loss = None # NEW: Topology distance loss
843
 
844
  if seq_labels is not None:
845
  loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
@@ -921,11 +921,11 @@ class MultimodalGlycanBERT(nn.Module):
921
  'seq_loss': seq_loss,
922
  'ms_loss': ms_loss,
923
  'struct_loss': struct_loss,
924
- 'dist_loss': dist_loss, # NEW: Topology loss
925
  'seq_logits': seq_logits,
926
  'ms_logits': ms_logits,
927
  'struct_logits': struct_logits,
928
- 'dist_predictions': dist_predictions, # NEW: Distance predictions
929
  'seq_hidden': seq_hidden,
930
  'ms_hidden': ms_hidden,
931
  'struct_hidden': struct_hidden,
@@ -970,7 +970,7 @@ class MultimodalGlycanBERT(nn.Module):
970
  if __name__ == "__main__":
971
  # Test the model
972
  print("="*80)
973
- print("Testing Multimodal GlycanBERT v3")
974
  print("="*80)
975
 
976
  # Create config
@@ -1081,4 +1081,3 @@ if __name__ == "__main__":
1081
  print(f"\n{'='*80}")
1082
  print("Model Test Complete!")
1083
  print("="*80)
1084
-
 
1
  """
2
+ Bertose model
3
 
4
+ Core glycan representation model with three modalities:
5
  - Sequence (WURCS atomic tokenization)
6
  - MS (mass spectrometry peaks, RT, intensity)
7
  - 3D structure (VQ-VAE discrete tokens, 4 per residue)
 
40
  config.max_position_embeddings, config.hidden_size
41
  )
42
 
43
+ # Branch depth embeddings encode depth in the glycan tree.
44
  max_branch_depth = getattr(config, "max_branch_depth", 8)
45
  self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size)
46
 
47
+ # Linkage type embeddings encode glycosidic bond chemistry.
48
  # 0=none, 1=1-3, 2=1-4, 3=1-6, etc.
49
  num_linkage_types = getattr(config, "num_linkage_types", 9)
50
  self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size)
 
83
  position_ids = self.position_ids[:, :seq_len]
84
  x = x + self.position_embeddings(position_ids)
85
 
86
+ # Add branch depth embeddings.
87
  if branch_depths is not None:
88
  # Clamp to valid range
89
  branch_depths = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1)
90
  x = x + self.branch_embeddings(branch_depths)
91
 
92
+ # Add linkage type embeddings.
93
  if linkage_types is not None:
94
  linkage_types = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1)
95
  x = x + self.linkage_embeddings(linkage_types)
 
148
 
149
 
150
  class MultimodalGlycanBERTConfig:
151
+ """Configuration for the Bertose model."""
152
 
153
  def __init__(
154
  self,
 
247
  self.seq_loss_weight = seq_loss_weight
248
  self.ms_loss_weight = ms_loss_weight
249
  self.struct_loss_weight = struct_loss_weight
250
+ self.dist_loss_weight = 0.25
251
 
252
  # Token IDs
253
  self.pad_token_id = pad_token_id
 
610
 
611
  class MultimodalGlycanBERT(nn.Module):
612
  """
613
+ Bertose model for glycan representation learning.
614
 
615
  Architecture:
616
  1. Separate encoders for each modality (sequence, MS, 3D structure)
 
628
  seq_config.cnn_kernel_size = config.cnn_kernel_size
629
 
630
  if config.use_cnn_frontend:
631
+ print(f"Enabled convolutional front-end (kernel={config.cnn_kernel_size})")
632
  self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config)
633
  else:
634
  self.seq_embeddings = GlycanBERTEmbeddings(seq_config)
 
675
  )
676
 
677
  # ===== Distance Prediction Head (Topology) =====
678
+ # Project down to 128 dimensions first to reduce memory use.
679
  # (Batch, 256, 256, 768) -> (Batch, 256, 256, 128) reduces memory by 6x
680
  self.dist_proj = nn.Linear(config.seq_hidden_size, 128)
681
  self.distance_head = nn.Sequential(
 
706
  seq_token_ids: torch.Tensor,
707
  seq_attention_mask: torch.Tensor,
708
  seq_residue_ids: torch.Tensor,
709
+ seq_branch_depths: Optional[torch.Tensor] = None,
710
+ seq_linkage_types: Optional[torch.Tensor] = None,
711
  ms_token_ids: torch.Tensor = None,
712
  ms_attention_mask: torch.Tensor = None,
713
  has_ms: torch.Tensor = None,
 
718
  seq_labels: Optional[torch.Tensor] = None,
719
  ms_labels: Optional[torch.Tensor] = None,
720
  struct_labels: Optional[torch.Tensor] = None,
721
+ dist_labels: Optional[torch.Tensor] = None,
722
  return_dict: bool = True,
723
  ) -> Dict[str, torch.Tensor]:
724
  """
725
+ Forward pass for Bertose.
726
 
727
  Args:
728
  seq_token_ids: (batch_size, seq_len) - Sequence token IDs
 
839
  seq_loss = None
840
  ms_loss = None
841
  struct_loss = None
842
+ dist_loss = None
843
 
844
  if seq_labels is not None:
845
  loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
 
921
  'seq_loss': seq_loss,
922
  'ms_loss': ms_loss,
923
  'struct_loss': struct_loss,
924
+ 'dist_loss': dist_loss,
925
  'seq_logits': seq_logits,
926
  'ms_logits': ms_logits,
927
  'struct_logits': struct_logits,
928
+ 'dist_predictions': dist_predictions,
929
  'seq_hidden': seq_hidden,
930
  'ms_hidden': ms_hidden,
931
  'struct_hidden': struct_hidden,
 
970
  if __name__ == "__main__":
971
  # Test the model
972
  print("="*80)
973
+ print("Testing Bertose model")
974
  print("="*80)
975
 
976
  # Create config
 
1081
  print(f"\n{'='*80}")
1082
  print("Model Test Complete!")
1083
  print("="*80)