manpreet88 commited on
Commit
6a3741a
·
1 Parent(s): 949da3f

Update CL.py

Browse files
Files changed (1) hide show
  1. PolyFusion/CL.py +26 -421
PolyFusion/CL.py CHANGED
@@ -22,19 +22,14 @@ import torch.nn as nn
22
  import torch.nn.functional as F
23
  from torch.utils.data import Dataset, DataLoader
24
 
25
- # PyG building blocks
26
- try:
27
- from torch_geometric.nn import GINEConv
28
- from torch_geometric.nn.models import SchNet as PyGSchNet
29
- from torch_geometric.nn import radius_graph
30
- except Exception as e:
31
- # we keep imports guarded — if these fail, user will get a clear error later
32
- GINEConv = None
33
- PyGSchNet = None
34
- radius_graph = None
35
 
36
  # HF Trainer & Transformers
37
- from transformers import TrainingArguments, Trainer, DebertaV2ForMaskedLM, DebertaV2Tokenizer
38
  from transformers import DataCollatorForLanguageModeling
39
  from transformers.trainer_callback import TrainerCallback
40
 
@@ -42,7 +37,7 @@ from sklearn.model_selection import train_test_split
42
  from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
43
 
44
  # ---------------------------
45
- # Config / Hyperparams (kept same as in your scripts)
46
  # ---------------------------
47
  P_MASK = 0.15
48
  MAX_ATOMIC_Z = 85
@@ -53,19 +48,19 @@ NODE_EMB_DIM = 300
53
  EDGE_EMB_DIM = 300
54
  NUM_GNN_LAYERS = 5
55
 
56
- # SchNet params (from your file)
57
  SCHNET_NUM_GAUSSIANS = 50
58
  SCHNET_NUM_INTERACTIONS = 6
59
  SCHNET_CUTOFF = 10.0
60
  SCHNET_MAX_NEIGHBORS = 64
61
  SCHNET_HIDDEN = 600
62
 
63
- # Fingerprint (MLM) params
64
  FP_LENGTH = 2048
65
- MASK_TOKEN_ID_FP = 2 # consistent with your fingerprint file
66
  VOCAB_SIZE_FP = 3
67
 
68
- # PSMILES/Deberta params (from your file)
69
  DEBERTA_HIDDEN = 600
70
  PSMILES_MAX_LEN = 128
71
 
@@ -73,15 +68,15 @@ PSMILES_MAX_LEN = 128
73
  TEMPERATURE = 0.07
74
 
75
  # Reconstruction loss weight (balance between contrastive and reconstruction objectives)
76
- REC_LOSS_WEIGHT = 1.0 # you can tune this (e.g., 0.5, 1.0)
77
 
78
- # Training args (same across files)
79
- OUTPUT_DIR = "./multimodal_output"
80
  os.makedirs(OUTPUT_DIR, exist_ok=True)
81
- BEST_GINE_DIR = "./gin_output/best"
82
- BEST_SCHNET_DIR = "./schnet_output/best"
83
- BEST_FP_DIR = "./fingerprint_mlm_output/best"
84
- BEST_PSMILES_DIR = "./polybert_output/best"
85
 
86
  training_args = TrainingArguments(
87
  output_dir=OUTPUT_DIR,
@@ -127,52 +122,7 @@ if USE_CUDA:
127
  # Utility / small helpers
128
  # ---------------------------
129
 
130
- def safe_get(d: dict, key: str, default=None):
131
- return d[key] if (isinstance(d, dict) and key in d) else default
132
-
133
-
134
- def match_edge_attr_to_index(edge_index: torch.Tensor, edge_attr: torch.Tensor, target_dim: int = 3):
135
- # determine device to allocate zero tensors on
136
- dev = None
137
- if edge_attr is not None and hasattr(edge_attr, "device"):
138
- dev = edge_attr.device
139
- elif edge_index is not None and hasattr(edge_index, "device"):
140
- dev = edge_index.device
141
- else:
142
- dev = torch.device("cpu")
143
-
144
- if edge_index is None or edge_index.numel() == 0:
145
- return torch.zeros((0, target_dim), dtype=torch.float, device=dev)
146
- E_idx = edge_index.size(1)
147
- if edge_attr is None or edge_attr.numel() == 0:
148
- return torch.zeros((E_idx, target_dim), dtype=torch.float, device=dev)
149
- E_attr = edge_attr.size(0)
150
- if E_attr == E_idx:
151
- if edge_attr.size(1) != target_dim:
152
- D = edge_attr.size(1)
153
- if D < target_dim:
154
- pad = torch.zeros((E_attr, target_dim - D), dtype=torch.float, device=edge_attr.device)
155
- return torch.cat([edge_attr, pad], dim=1)
156
- else:
157
- return edge_attr[:, :target_dim]
158
- return edge_attr
159
- if E_attr * 2 == E_idx:
160
- try:
161
- return torch.cat([edge_attr, edge_attr], dim=0)
162
- except Exception:
163
- pass
164
- reps = (E_idx + E_attr - 1) // E_attr
165
- edge_rep = edge_attr.repeat(reps, 1)[:E_idx]
166
- if edge_rep.size(1) != target_dim:
167
- D = edge_rep.size(1)
168
- if D < target_dim:
169
- pad = torch.zeros((E_idx, target_dim - D), dtype=torch.float, device=edge_rep.device)
170
- edge_rep = torch.cat([edge_rep, pad], dim=1)
171
- else:
172
- edge_rep = edge_rep[:, :target_dim]
173
- return edge_rep
174
-
175
- # Optimized BFS to compute distances to visible anchors (used if needed)
176
  def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_idx: np.ndarray, visible_idx: np.ndarray, k_anchors: int):
177
  INF = num_nodes + 1
178
  selected_dists = np.zeros((num_nodes, k_anchors), dtype=np.float32)
@@ -213,13 +163,13 @@ def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_id
213
  return selected_dists, selected_mask
214
 
215
  # ---------------------------
216
- # Data loading / preprocessing (streaming to disk to avoid memory spike)
217
  # ---------------------------
218
- CSV_PATH = "Polymer_Foundational_Model/polymer_structures_unified_processed.csv"
219
  TARGET_ROWS = 2000000
220
  CHUNKSIZE = 50000
221
 
222
- PREPROC_DIR = "preprocessed_samples"
223
  os.makedirs(PREPROC_DIR, exist_ok=True)
224
 
225
  # The per-sample file format: torch.save(sample_dict, sample_path)
@@ -240,7 +190,6 @@ def prepare_or_load_data_streaming():
240
  rows_read = 0
241
  sample_idx = 0
242
 
243
- # We'll parse CSV in chunks and for each row, if it contains all modalities, write sample to disk
244
  for chunk in pd.read_csv(CSV_PATH, engine="python", chunksize=CHUNKSIZE):
245
  # Pre-extract columns presence
246
  has_graph = "graph" in chunk.columns
@@ -446,48 +395,10 @@ def prepare_or_load_data_streaming():
446
  sample_files = prepare_or_load_data_streaming()
447
 
448
  # ---------------------------
449
- # Prepare tokenizer for psmiles (deferred, but we still attempt HF tokenizer; fallback created)
450
  # ---------------------------
451
- try:
452
- SPM_MODEL = "spm.model"
453
- if Path(SPM_MODEL).exists():
454
- tokenizer = DebertaV2Tokenizer(vocab_file=SPM_MODEL, do_lower_case=False)
455
- tokenizer.add_special_tokens({"pad_token": "<pad>", "mask_token": "<mask>"})
456
- tokenizer.pad_token = "<pad>"
457
- tokenizer.mask_token = "<mask>"
458
- else:
459
- tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v2-xlarge", use_fast=False)
460
- tokenizer.add_special_tokens({"pad_token": "<pad>", "mask_token": "<mask>"})
461
- tokenizer.pad_token = "<pad>"
462
- tokenizer.mask_token = "<mask>"
463
- except Exception as e:
464
- print("Warning: Deberta tokenizer creation failed:", e)
465
- # create a simple fallback tokenizer (char-level)
466
- class SimplePSMILESTokenizer:
467
- def __init__(self, max_len=PSMILES_MAX_LEN):
468
- chars = list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-=#()[]@+/\\.")
469
- self.vocab = {c: i + 5 for i, c in enumerate(chars)}
470
- self.vocab["<pad>"] = 0
471
- self.vocab["<mask>"] = 1
472
- self.vocab["<unk>"] = 2
473
- self.vocab["<cls>"] = 3
474
- self.vocab["<sep>"] = 4
475
- self.mask_token = "<mask>"
476
- self.mask_token_id = self.vocab[self.mask_token]
477
- self.vocab_size = len(self.vocab)
478
- self.max_len = max_len
479
-
480
- def __call__(self, s, truncation=True, padding="max_length", max_length=None):
481
- max_len = max_length or self.max_len
482
- toks = [self.vocab.get(ch, self.vocab["<unk>"]) for ch in list(s)][:max_len]
483
- attn = [1] * len(toks)
484
- if len(toks) < max_len:
485
- pad = [self.vocab["<pad>"]] * (max_len - len(toks))
486
- toks = toks + pad
487
- attn = attn + [0] * (max_len - len(attn))
488
- return {"input_ids": toks, "attention_mask": attn}
489
-
490
- tokenizer = SimplePSMILESTokenizer()
491
 
492
  # ---------------------------
493
  # Lazy dataset: loads per-sample file on demand and tokenizes psmiles on-the-fly
@@ -672,313 +583,7 @@ train_loader = DataLoader(train_subset, batch_size=training_args.per_device_trai
672
  val_loader = DataLoader(val_subset, batch_size=training_args.per_device_eval_batch_size, shuffle=False, collate_fn=multimodal_collate, num_workers=0, drop_last=False)
673
 
674
  # ---------------------------
675
- # Encoder definitions (kept same as original with minimal device-safe guards)
676
- # ---------------------------
677
-
678
- class GineBlock(nn.Module):
679
- def __init__(self, node_dim):
680
- super().__init__()
681
- self.mlp = nn.Sequential(
682
- nn.Linear(node_dim, node_dim),
683
- nn.ReLU(),
684
- nn.Linear(node_dim, node_dim)
685
- )
686
- # If GINEConv is not available, we still construct placeholder to fail later with message
687
- if GINEConv is None:
688
- raise RuntimeError("GINEConv is not available. Install torch_geometric with compatible versions.")
689
- self.conv = GINEConv(self.mlp)
690
- self.bn = nn.BatchNorm1d(node_dim)
691
- self.act = nn.ReLU()
692
-
693
- def forward(self, x, edge_index, edge_attr):
694
- x = self.conv(x, edge_index, edge_attr)
695
- x = self.bn(x)
696
- x = self.act(x)
697
- return x
698
-
699
- class GineEncoder(nn.Module):
700
- def __init__(self, node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS, max_atomic_z=MAX_ATOMIC_Z):
701
- super().__init__()
702
- self.atom_emb = nn.Embedding(num_embeddings=MASK_ATOM_ID+1, embedding_dim=node_emb_dim, padding_idx=None)
703
- self.node_attr_proj = nn.Sequential(
704
- nn.Linear(2, node_emb_dim),
705
- nn.ReLU(),
706
- nn.Linear(node_emb_dim, node_emb_dim)
707
- )
708
- self.edge_encoder = nn.Sequential(
709
- nn.Linear(3, edge_emb_dim),
710
- nn.ReLU(),
711
- nn.Linear(edge_emb_dim, edge_emb_dim)
712
- )
713
- if edge_emb_dim != node_emb_dim:
714
- self._edge_to_node_proj = nn.Linear(edge_emb_dim, node_emb_dim)
715
- else:
716
- self._edge_to_node_proj = None
717
- self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
718
- # global pooling projection
719
- self.pool_proj = nn.Linear(node_emb_dim, node_emb_dim)
720
-
721
- # node-level classifier head for reconstructing atomic ids if needed
722
- self.node_classifier = nn.Linear(node_emb_dim, MASK_ATOM_ID+1)
723
-
724
- def _compute_node_reps(self, z, chirality, formal_charge, edge_index, edge_attr):
725
- device = next(self.parameters()).device
726
- atom_embedding = self.atom_emb(z.to(device))
727
- if chirality is None or formal_charge is None:
728
- node_attr = torch.zeros((z.size(0), 2), device=device)
729
- else:
730
- node_attr = torch.stack([chirality, formal_charge], dim=1).to(atom_embedding.device)
731
- node_attr_emb = self.node_attr_proj(node_attr)
732
- x = atom_embedding + node_attr_emb
733
- if edge_attr is None or edge_attr.numel() == 0:
734
- edge_emb = torch.zeros((0, EDGE_EMB_DIM), dtype=torch.float, device=x.device)
735
- else:
736
- edge_emb = self.edge_encoder(edge_attr.to(x.device))
737
- if self._edge_to_node_proj is not None and edge_emb.numel() > 0:
738
- edge_for_conv = self._edge_to_node_proj(edge_emb)
739
- else:
740
- edge_for_conv = edge_emb
741
-
742
- h = x
743
- for layer in self.gnn_layers:
744
- h = layer(h, edge_index.to(h.device), edge_for_conv)
745
- return h
746
-
747
- def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None):
748
- h = self._compute_node_reps(z, chirality, formal_charge, edge_index, edge_attr)
749
- if batch is None:
750
- pooled = torch.mean(h, dim=0, keepdim=True)
751
- else:
752
- bsize = int(batch.max().item() + 1) if batch.numel() > 0 else 1
753
- pooled = torch.zeros((bsize, h.size(1)), device=h.device)
754
- for i in range(bsize):
755
- mask = batch == i
756
- if mask.sum() == 0:
757
- continue
758
- pooled[i] = h[mask].mean(dim=0)
759
- return self.pool_proj(pooled)
760
-
761
- def node_logits(self, z, chirality, formal_charge, edge_index, edge_attr):
762
- h = self._compute_node_reps(z, chirality, formal_charge, edge_index, edge_attr)
763
- logits = self.node_classifier(h)
764
- return logits
765
-
766
- class NodeSchNetWrapper(nn.Module):
767
- def __init__(self, hidden_channels=SCHNET_HIDDEN, num_interactions=SCHNET_NUM_INTERACTIONS, num_gaussians=SCHNET_NUM_GAUSSIANS, cutoff=SCHNET_CUTOFF, max_num_neighbors=SCHNET_MAX_NEIGHBORS):
768
- super().__init__()
769
- if PyGSchNet is None:
770
- raise RuntimeError("PyG SchNet is not available. Install torch_geometric with compatible extras.")
771
- self.schnet = PyGSchNet(hidden_channels=hidden_channels, num_filters=hidden_channels, num_interactions=num_interactions, num_gaussians=num_gaussians, cutoff=cutoff, max_num_neighbors=max_num_neighbors)
772
- self.pool_proj = nn.Linear(hidden_channels, hidden_channels)
773
- self.cutoff = cutoff
774
- self.max_num_neighbors = max_num_neighbors
775
- self.node_classifier = nn.Linear(hidden_channels, MASK_ATOM_ID+1)
776
-
777
- def forward(self, z, pos, batch=None):
778
- device = next(self.parameters()).device
779
- z = z.to(device)
780
- pos = pos.to(device)
781
- if batch is None:
782
- batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
783
- try:
784
- edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors)
785
- except Exception:
786
- edge_index = None
787
-
788
- node_h = None
789
- try:
790
- if hasattr(self.schnet, "embedding"):
791
- node_h = self.schnet.embedding(z)
792
- else:
793
- node_h = self.schnet.embedding(z)
794
- except Exception:
795
- node_h = None
796
-
797
- if node_h is not None and edge_index is not None and edge_index.numel() > 0:
798
- row, col = edge_index
799
- edge_weight = (pos[row] - pos[col]).norm(dim=-1)
800
- edge_attr = None
801
- if hasattr(self.schnet, "distance_expansion"):
802
- try:
803
- edge_attr = self.schnet.distance_expansion(edge_weight)
804
- except Exception:
805
- edge_attr = None
806
- if edge_attr is None and hasattr(self.schnet, "gaussian_smearing"):
807
- try:
808
- edge_attr = self.schnet.gaussian_smearing(edge_weight)
809
- except Exception:
810
- edge_attr = None
811
- if hasattr(self.schnet, "interactions") and getattr(self.schnet, "interactions") is not None:
812
- for interaction in self.schnet.interactions:
813
- try:
814
- node_h = node_h + interaction(node_h, edge_index, edge_weight, edge_attr)
815
- except TypeError:
816
- node_h = node_h + interaction(node_h, edge_index, edge_weight)
817
- if node_h is None:
818
- try:
819
- out = self.schnet(z=z, pos=pos, batch=batch)
820
- if isinstance(out, torch.Tensor) and out.dim() == 2 and out.size(0) == z.size(0):
821
- node_h = out
822
- elif hasattr(out, "last_hidden_state"):
823
- node_h = out.last_hidden_state
824
- elif isinstance(out, (tuple, list)) and len(out) > 0 and isinstance(out[0], torch.Tensor):
825
- cand = out[0]
826
- if cand.dim() == 2 and cand.size(0) == z.size(0):
827
- node_h = cand
828
- except Exception as e:
829
- raise RuntimeError("Failed to obtain node-level embeddings from PyG SchNet.") from e
830
-
831
- bsize = int(batch.max().item()) + 1 if z.numel() > 0 else 1
832
- pooled = torch.zeros((bsize, node_h.size(1)), device=node_h.device)
833
- for i in range(bsize):
834
- mask = batch == i
835
- if mask.sum() == 0:
836
- continue
837
- pooled[i] = node_h[mask].mean(dim=0)
838
- return self.pool_proj(pooled)
839
-
840
- def node_logits(self, z, pos, batch=None):
841
- device = next(self.parameters()).device
842
- z = z.to(device)
843
- pos = pos.to(device)
844
- if batch is None:
845
- batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
846
- try:
847
- edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors)
848
- except Exception:
849
- edge_index = None
850
-
851
- node_h = None
852
- try:
853
- if hasattr(self.schnet, "embedding"):
854
- node_h = self.schnet.embedding(z)
855
- except Exception:
856
- node_h = None
857
-
858
- if node_h is not None and edge_index is not None and edge_index.numel() > 0:
859
- row, col = edge_index
860
- edge_weight = (pos[row] - pos[col]).norm(dim=-1)
861
- edge_attr = None
862
- if hasattr(self.schnet, "distance_expansion"):
863
- try:
864
- edge_attr = self.schnet.distance_expansion(edge_weight)
865
- except Exception:
866
- edge_attr = None
867
- if edge_attr is None and hasattr(self.schnet, "gaussian_smearing"):
868
- try:
869
- edge_attr = self.schnet.gaussian_smearing(edge_weight)
870
- except Exception:
871
- edge_attr = None
872
- if hasattr(self.schnet, "interactions") and getattr(self.schnet, "interactions") is not None:
873
- for interaction in self.schnet.interactions:
874
- try:
875
- node_h = node_h + interaction(node_h, edge_index, edge_weight, edge_attr)
876
- except TypeError:
877
- node_h = node_h + interaction(node_h, edge_index, edge_weight)
878
-
879
- if node_h is None:
880
- out = self.schnet(z=z, pos=pos, batch=batch)
881
- if isinstance(out, torch.Tensor):
882
- node_h = out
883
- elif hasattr(out, "last_hidden_state"):
884
- node_h = out.last_hidden_state
885
- elif isinstance(out, (tuple, list)) and len(out) > 0 and isinstance(out[0], torch.Tensor):
886
- node_h = out[0]
887
- else:
888
- raise RuntimeError("Unable to obtain node embeddings for SchNet node_logits")
889
-
890
- logits = self.node_classifier(node_h)
891
- return logits
892
-
893
- class FingerprintEncoder(nn.Module):
894
- def __init__(self, vocab_size=VOCAB_SIZE_FP, hidden_dim=256, seq_len=FP_LENGTH, num_layers=4, nhead=8, dim_feedforward=1024, dropout=0.1):
895
- super().__init__()
896
- self.token_emb = nn.Embedding(vocab_size, hidden_dim)
897
- self.pos_emb = nn.Embedding(seq_len, hidden_dim)
898
- encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
899
- self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
900
- self.pool_proj = nn.Linear(hidden_dim, hidden_dim)
901
- self.seq_len = seq_len
902
- self.token_proj = nn.Linear(hidden_dim, vocab_size)
903
-
904
- def forward(self, input_ids, attention_mask=None):
905
- device = next(self.parameters()).device
906
- input_ids = input_ids.to(device)
907
- B, L = input_ids.shape
908
- x = self.token_emb(input_ids)
909
- pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
910
- x = x + self.pos_emb(pos_ids)
911
- if attention_mask is not None:
912
- key_padding_mask = ~attention_mask.to(input_ids.device)
913
- else:
914
- key_padding_mask = None
915
- out = self.transformer(x, src_key_padding_mask=key_padding_mask)
916
- if attention_mask is None:
917
- pooled = out.mean(dim=1)
918
- else:
919
- am = attention_mask.to(out.device).float().unsqueeze(-1)
920
- pooled = (out * am).sum(dim=1) / (am.sum(dim=1).clamp(min=1.0))
921
- return self.pool_proj(pooled)
922
-
923
- def token_logits(self, input_ids, attention_mask=None):
924
- device = next(self.parameters()).device
925
- input_ids = input_ids.to(device)
926
- B, L = input_ids.shape
927
- x = self.token_emb(input_ids)
928
- pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
929
- x = x + self.pos_emb(pos_ids)
930
- if attention_mask is not None:
931
- key_padding_mask = ~attention_mask.to(input_ids.device)
932
- else:
933
- key_padding_mask = None
934
- out = self.transformer(x, src_key_padding_mask=key_padding_mask)
935
- logits = self.token_proj(out)
936
- return logits
937
-
938
- class PSMILESDebertaEncoder(nn.Module):
939
- def __init__(self, model_dir_or_name: Optional[str] = None):
940
- super().__init__()
941
- try:
942
- if model_dir_or_name is not None and os.path.isdir(model_dir_or_name):
943
- self.model = DebertaV2ForMaskedLM.from_pretrained(model_dir_or_name)
944
- else:
945
- self.model = DebertaV2ForMaskedLM.from_pretrained("microsoft/deberta-v2-xlarge")
946
- except Exception as e:
947
- print("Warning: couldn't load DebertaV2 pretrained weights; initializing randomly.", e)
948
- from transformers import DebertaV2Config
949
- cfg = DebertaV2Config(vocab_size=getattr(tokenizer, "vocab_size", 300), hidden_size=DEBERTA_HIDDEN, num_attention_heads=12, num_hidden_layers=12, intermediate_size=512)
950
- self.model = DebertaV2ForMaskedLM(cfg)
951
- self.pool_proj = nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size)
952
-
953
- def forward(self, input_ids, attention_mask=None):
954
- device = next(self.parameters()).device
955
- input_ids = input_ids.to(device)
956
- if attention_mask is not None:
957
- attention_mask = attention_mask.to(device)
958
- outputs = self.model.base_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
959
- last_hidden = outputs.last_hidden_state
960
- if attention_mask is None:
961
- pooled = last_hidden.mean(dim=1)
962
- else:
963
- am = attention_mask.unsqueeze(-1).to(last_hidden.device).float()
964
- pooled = (last_hidden * am).sum(dim=1) / (am.sum(dim=1).clamp(min=1.0))
965
- return self.pool_proj(pooled)
966
-
967
- def token_logits(self, input_ids, attention_mask=None, labels=None):
968
- device = next(self.parameters()).device
969
- input_ids = input_ids.to(device)
970
- if attention_mask is not None:
971
- attention_mask = attention_mask.to(device)
972
- if labels is not None:
973
- labels = labels.to(device)
974
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
975
- return outputs.loss
976
- else:
977
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
978
- return outputs.logits
979
-
980
- # ---------------------------
981
- # Multimodal wrapper & loss (kept same)
982
  # ---------------------------
983
  class MultimodalContrastiveModel(nn.Module):
984
  def __init__(self,
 
22
  import torch.nn.functional as F
23
  from torch.utils.data import Dataset, DataLoader
24
 
25
+ # Shared model utilities
26
+ from GINE import GineEncoder, match_edge_attr_to_index, safe_get
27
+ from SchNet import NodeSchNetWrapper
28
+ from Transformer import PooledFingerprintEncoder as FingerprintEncoder
29
+ from DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
 
 
 
 
 
30
 
31
  # HF Trainer & Transformers
32
+ from transformers import TrainingArguments, Trainer
33
  from transformers import DataCollatorForLanguageModeling
34
  from transformers.trainer_callback import TrainerCallback
35
 
 
37
  from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
38
 
39
  # ---------------------------
40
+ # Config / Hyperparams (paths are placeholders; update for your environment)
41
  # ---------------------------
42
  P_MASK = 0.15
43
  MAX_ATOMIC_Z = 85
 
48
  EDGE_EMB_DIM = 300
49
  NUM_GNN_LAYERS = 5
50
 
51
+ # SchNet params
52
  SCHNET_NUM_GAUSSIANS = 50
53
  SCHNET_NUM_INTERACTIONS = 6
54
  SCHNET_CUTOFF = 10.0
55
  SCHNET_MAX_NEIGHBORS = 64
56
  SCHNET_HIDDEN = 600
57
 
58
+ # Transformer params
59
  FP_LENGTH = 2048
60
+ MASK_TOKEN_ID_FP = 2
61
  VOCAB_SIZE_FP = 3
62
 
63
+ # DeBERTav2 params
64
  DEBERTA_HIDDEN = 600
65
  PSMILES_MAX_LEN = 128
66
 
 
68
  TEMPERATURE = 0.07
69
 
70
  # Reconstruction loss weight (balance between contrastive and reconstruction objectives)
71
+ REC_LOSS_WEIGHT = 1.0 # could be tuned (e.g., 0.5, 1.0)
72
 
73
+ # Training args
74
+ OUTPUT_DIR = "/path/to/multimodal_output"
75
  os.makedirs(OUTPUT_DIR, exist_ok=True)
76
+ BEST_GINE_DIR = "/path/to/gin_output/best"
77
+ BEST_SCHNET_DIR = "/path/to/schnet_output/best"
78
+ BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
79
+ BEST_PSMILES_DIR = "/path/to/polybert_output/best"
80
 
81
  training_args = TrainingArguments(
82
  output_dir=OUTPUT_DIR,
 
122
  # Utility / small helpers
123
  # ---------------------------
124
 
125
+ # Optimized BFS to compute distances to visible anchors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_idx: np.ndarray, visible_idx: np.ndarray, k_anchors: int):
127
  INF = num_nodes + 1
128
  selected_dists = np.zeros((num_nodes, k_anchors), dtype=np.float32)
 
163
  return selected_dists, selected_mask
164
 
165
  # ---------------------------
166
+ # Data loading / preprocessing
167
  # ---------------------------
168
+ CSV_PATH = "/path/to/polymer_structures_unified_processed.csv"
169
  TARGET_ROWS = 2000000
170
  CHUNKSIZE = 50000
171
 
172
+ PREPROC_DIR = "/path/to/preprocessed_samples"
173
  os.makedirs(PREPROC_DIR, exist_ok=True)
174
 
175
  # The per-sample file format: torch.save(sample_dict, sample_path)
 
190
  rows_read = 0
191
  sample_idx = 0
192
 
 
193
  for chunk in pd.read_csv(CSV_PATH, engine="python", chunksize=CHUNKSIZE):
194
  # Pre-extract columns presence
195
  has_graph = "graph" in chunk.columns
 
395
  sample_files = prepare_or_load_data_streaming()
396
 
397
  # ---------------------------
398
+ # Prepare tokenizer for psmiles
399
  # ---------------------------
400
+ SPM_MODEL = "/path/to/spm.model"
401
+ tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
  # ---------------------------
404
  # Lazy dataset: loads per-sample file on demand and tokenizes psmiles on-the-fly
 
583
  val_loader = DataLoader(val_subset, batch_size=training_args.per_device_eval_batch_size, shuffle=False, collate_fn=multimodal_collate, num_workers=0, drop_last=False)
584
 
585
  # ---------------------------
586
+ # Multimodal wrapper & loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
  # ---------------------------
588
  class MultimodalContrastiveModel(nn.Module):
589
  def __init__(self,