Spaces:
Running
Running
manpreet88 committed on
Commit ·
6a3741a
1
Parent(s): 949da3f
Update CL.py
Browse files
- PolyFusion/CL.py +26 -421
PolyFusion/CL.py
CHANGED
|
@@ -22,19 +22,14 @@ import torch.nn as nn
|
|
| 22 |
import torch.nn.functional as F
|
| 23 |
from torch.utils.data import Dataset, DataLoader
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
except Exception as e:
|
| 31 |
-
# we keep imports guarded — if these fail, user will get a clear error later
|
| 32 |
-
GINEConv = None
|
| 33 |
-
PyGSchNet = None
|
| 34 |
-
radius_graph = None
|
| 35 |
|
| 36 |
# HF Trainer & Transformers
|
| 37 |
-
from transformers import TrainingArguments, Trainer
|
| 38 |
from transformers import DataCollatorForLanguageModeling
|
| 39 |
from transformers.trainer_callback import TrainerCallback
|
| 40 |
|
|
@@ -42,7 +37,7 @@ from sklearn.model_selection import train_test_split
|
|
| 42 |
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
|
| 43 |
|
| 44 |
# ---------------------------
|
| 45 |
-
# Config / Hyperparams (
|
| 46 |
# ---------------------------
|
| 47 |
P_MASK = 0.15
|
| 48 |
MAX_ATOMIC_Z = 85
|
|
@@ -53,19 +48,19 @@ NODE_EMB_DIM = 300
|
|
| 53 |
EDGE_EMB_DIM = 300
|
| 54 |
NUM_GNN_LAYERS = 5
|
| 55 |
|
| 56 |
-
# SchNet params
|
| 57 |
SCHNET_NUM_GAUSSIANS = 50
|
| 58 |
SCHNET_NUM_INTERACTIONS = 6
|
| 59 |
SCHNET_CUTOFF = 10.0
|
| 60 |
SCHNET_MAX_NEIGHBORS = 64
|
| 61 |
SCHNET_HIDDEN = 600
|
| 62 |
|
| 63 |
-
#
|
| 64 |
FP_LENGTH = 2048
|
| 65 |
-
MASK_TOKEN_ID_FP = 2
|
| 66 |
VOCAB_SIZE_FP = 3
|
| 67 |
|
| 68 |
-
#
|
| 69 |
DEBERTA_HIDDEN = 600
|
| 70 |
PSMILES_MAX_LEN = 128
|
| 71 |
|
|
@@ -73,15 +68,15 @@ PSMILES_MAX_LEN = 128
|
|
| 73 |
TEMPERATURE = 0.07
|
| 74 |
|
| 75 |
# Reconstruction loss weight (balance between contrastive and reconstruction objectives)
|
| 76 |
-
REC_LOSS_WEIGHT = 1.0 #
|
| 77 |
|
| 78 |
-
# Training args
|
| 79 |
-
OUTPUT_DIR = "
|
| 80 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 81 |
-
BEST_GINE_DIR = "
|
| 82 |
-
BEST_SCHNET_DIR = "
|
| 83 |
-
BEST_FP_DIR = "
|
| 84 |
-
BEST_PSMILES_DIR = "
|
| 85 |
|
| 86 |
training_args = TrainingArguments(
|
| 87 |
output_dir=OUTPUT_DIR,
|
|
@@ -127,52 +122,7 @@ if USE_CUDA:
|
|
| 127 |
# Utility / small helpers
|
| 128 |
# ---------------------------
|
| 129 |
|
| 130 |
-
|
| 131 |
-
return d[key] if (isinstance(d, dict) and key in d) else default
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def match_edge_attr_to_index(edge_index: torch.Tensor, edge_attr: torch.Tensor, target_dim: int = 3):
|
| 135 |
-
# determine device to allocate zero tensors on
|
| 136 |
-
dev = None
|
| 137 |
-
if edge_attr is not None and hasattr(edge_attr, "device"):
|
| 138 |
-
dev = edge_attr.device
|
| 139 |
-
elif edge_index is not None and hasattr(edge_index, "device"):
|
| 140 |
-
dev = edge_index.device
|
| 141 |
-
else:
|
| 142 |
-
dev = torch.device("cpu")
|
| 143 |
-
|
| 144 |
-
if edge_index is None or edge_index.numel() == 0:
|
| 145 |
-
return torch.zeros((0, target_dim), dtype=torch.float, device=dev)
|
| 146 |
-
E_idx = edge_index.size(1)
|
| 147 |
-
if edge_attr is None or edge_attr.numel() == 0:
|
| 148 |
-
return torch.zeros((E_idx, target_dim), dtype=torch.float, device=dev)
|
| 149 |
-
E_attr = edge_attr.size(0)
|
| 150 |
-
if E_attr == E_idx:
|
| 151 |
-
if edge_attr.size(1) != target_dim:
|
| 152 |
-
D = edge_attr.size(1)
|
| 153 |
-
if D < target_dim:
|
| 154 |
-
pad = torch.zeros((E_attr, target_dim - D), dtype=torch.float, device=edge_attr.device)
|
| 155 |
-
return torch.cat([edge_attr, pad], dim=1)
|
| 156 |
-
else:
|
| 157 |
-
return edge_attr[:, :target_dim]
|
| 158 |
-
return edge_attr
|
| 159 |
-
if E_attr * 2 == E_idx:
|
| 160 |
-
try:
|
| 161 |
-
return torch.cat([edge_attr, edge_attr], dim=0)
|
| 162 |
-
except Exception:
|
| 163 |
-
pass
|
| 164 |
-
reps = (E_idx + E_attr - 1) // E_attr
|
| 165 |
-
edge_rep = edge_attr.repeat(reps, 1)[:E_idx]
|
| 166 |
-
if edge_rep.size(1) != target_dim:
|
| 167 |
-
D = edge_rep.size(1)
|
| 168 |
-
if D < target_dim:
|
| 169 |
-
pad = torch.zeros((E_idx, target_dim - D), dtype=torch.float, device=edge_rep.device)
|
| 170 |
-
edge_rep = torch.cat([edge_rep, pad], dim=1)
|
| 171 |
-
else:
|
| 172 |
-
edge_rep = edge_rep[:, :target_dim]
|
| 173 |
-
return edge_rep
|
| 174 |
-
|
| 175 |
-
# Optimized BFS to compute distances to visible anchors (used if needed)
|
| 176 |
def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_idx: np.ndarray, visible_idx: np.ndarray, k_anchors: int):
|
| 177 |
INF = num_nodes + 1
|
| 178 |
selected_dists = np.zeros((num_nodes, k_anchors), dtype=np.float32)
|
|
@@ -213,13 +163,13 @@ def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_id
|
|
| 213 |
return selected_dists, selected_mask
|
| 214 |
|
| 215 |
# ---------------------------
|
| 216 |
-
# Data loading / preprocessing
|
| 217 |
# ---------------------------
|
| 218 |
-
CSV_PATH = "
|
| 219 |
TARGET_ROWS = 2000000
|
| 220 |
CHUNKSIZE = 50000
|
| 221 |
|
| 222 |
-
PREPROC_DIR = "preprocessed_samples"
|
| 223 |
os.makedirs(PREPROC_DIR, exist_ok=True)
|
| 224 |
|
| 225 |
# The per-sample file format: torch.save(sample_dict, sample_path)
|
|
@@ -240,7 +190,6 @@ def prepare_or_load_data_streaming():
|
|
| 240 |
rows_read = 0
|
| 241 |
sample_idx = 0
|
| 242 |
|
| 243 |
-
# We'll parse CSV in chunks and for each row, if it contains all modalities, write sample to disk
|
| 244 |
for chunk in pd.read_csv(CSV_PATH, engine="python", chunksize=CHUNKSIZE):
|
| 245 |
# Pre-extract columns presence
|
| 246 |
has_graph = "graph" in chunk.columns
|
|
@@ -446,48 +395,10 @@ def prepare_or_load_data_streaming():
|
|
| 446 |
sample_files = prepare_or_load_data_streaming()
|
| 447 |
|
| 448 |
# ---------------------------
|
| 449 |
-
# Prepare tokenizer for psmiles
|
| 450 |
# ---------------------------
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
if Path(SPM_MODEL).exists():
|
| 454 |
-
tokenizer = DebertaV2Tokenizer(vocab_file=SPM_MODEL, do_lower_case=False)
|
| 455 |
-
tokenizer.add_special_tokens({"pad_token": "<pad>", "mask_token": "<mask>"})
|
| 456 |
-
tokenizer.pad_token = "<pad>"
|
| 457 |
-
tokenizer.mask_token = "<mask>"
|
| 458 |
-
else:
|
| 459 |
-
tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v2-xlarge", use_fast=False)
|
| 460 |
-
tokenizer.add_special_tokens({"pad_token": "<pad>", "mask_token": "<mask>"})
|
| 461 |
-
tokenizer.pad_token = "<pad>"
|
| 462 |
-
tokenizer.mask_token = "<mask>"
|
| 463 |
-
except Exception as e:
|
| 464 |
-
print("Warning: Deberta tokenizer creation failed:", e)
|
| 465 |
-
# create a simple fallback tokenizer (char-level)
|
| 466 |
-
class SimplePSMILESTokenizer:
|
| 467 |
-
def __init__(self, max_len=PSMILES_MAX_LEN):
|
| 468 |
-
chars = list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-=#()[]@+/\\.")
|
| 469 |
-
self.vocab = {c: i + 5 for i, c in enumerate(chars)}
|
| 470 |
-
self.vocab["<pad>"] = 0
|
| 471 |
-
self.vocab["<mask>"] = 1
|
| 472 |
-
self.vocab["<unk>"] = 2
|
| 473 |
-
self.vocab["<cls>"] = 3
|
| 474 |
-
self.vocab["<sep>"] = 4
|
| 475 |
-
self.mask_token = "<mask>"
|
| 476 |
-
self.mask_token_id = self.vocab[self.mask_token]
|
| 477 |
-
self.vocab_size = len(self.vocab)
|
| 478 |
-
self.max_len = max_len
|
| 479 |
-
|
| 480 |
-
def __call__(self, s, truncation=True, padding="max_length", max_length=None):
|
| 481 |
-
max_len = max_length or self.max_len
|
| 482 |
-
toks = [self.vocab.get(ch, self.vocab["<unk>"]) for ch in list(s)][:max_len]
|
| 483 |
-
attn = [1] * len(toks)
|
| 484 |
-
if len(toks) < max_len:
|
| 485 |
-
pad = [self.vocab["<pad>"]] * (max_len - len(toks))
|
| 486 |
-
toks = toks + pad
|
| 487 |
-
attn = attn + [0] * (max_len - len(attn))
|
| 488 |
-
return {"input_ids": toks, "attention_mask": attn}
|
| 489 |
-
|
| 490 |
-
tokenizer = SimplePSMILESTokenizer()
|
| 491 |
|
| 492 |
# ---------------------------
|
| 493 |
# Lazy dataset: loads per-sample file on demand and tokenizes psmiles on-the-fly
|
|
@@ -672,313 +583,7 @@ train_loader = DataLoader(train_subset, batch_size=training_args.per_device_trai
|
|
| 672 |
val_loader = DataLoader(val_subset, batch_size=training_args.per_device_eval_batch_size, shuffle=False, collate_fn=multimodal_collate, num_workers=0, drop_last=False)
|
| 673 |
|
| 674 |
# ---------------------------
|
| 675 |
-
#
|
| 676 |
-
# ---------------------------
|
| 677 |
-
|
| 678 |
-
class GineBlock(nn.Module):
|
| 679 |
-
def __init__(self, node_dim):
|
| 680 |
-
super().__init__()
|
| 681 |
-
self.mlp = nn.Sequential(
|
| 682 |
-
nn.Linear(node_dim, node_dim),
|
| 683 |
-
nn.ReLU(),
|
| 684 |
-
nn.Linear(node_dim, node_dim)
|
| 685 |
-
)
|
| 686 |
-
# If GINEConv is not available, we still construct placeholder to fail later with message
|
| 687 |
-
if GINEConv is None:
|
| 688 |
-
raise RuntimeError("GINEConv is not available. Install torch_geometric with compatible versions.")
|
| 689 |
-
self.conv = GINEConv(self.mlp)
|
| 690 |
-
self.bn = nn.BatchNorm1d(node_dim)
|
| 691 |
-
self.act = nn.ReLU()
|
| 692 |
-
|
| 693 |
-
def forward(self, x, edge_index, edge_attr):
|
| 694 |
-
x = self.conv(x, edge_index, edge_attr)
|
| 695 |
-
x = self.bn(x)
|
| 696 |
-
x = self.act(x)
|
| 697 |
-
return x
|
| 698 |
-
|
| 699 |
-
class GineEncoder(nn.Module):
|
| 700 |
-
def __init__(self, node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS, max_atomic_z=MAX_ATOMIC_Z):
|
| 701 |
-
super().__init__()
|
| 702 |
-
self.atom_emb = nn.Embedding(num_embeddings=MASK_ATOM_ID+1, embedding_dim=node_emb_dim, padding_idx=None)
|
| 703 |
-
self.node_attr_proj = nn.Sequential(
|
| 704 |
-
nn.Linear(2, node_emb_dim),
|
| 705 |
-
nn.ReLU(),
|
| 706 |
-
nn.Linear(node_emb_dim, node_emb_dim)
|
| 707 |
-
)
|
| 708 |
-
self.edge_encoder = nn.Sequential(
|
| 709 |
-
nn.Linear(3, edge_emb_dim),
|
| 710 |
-
nn.ReLU(),
|
| 711 |
-
nn.Linear(edge_emb_dim, edge_emb_dim)
|
| 712 |
-
)
|
| 713 |
-
if edge_emb_dim != node_emb_dim:
|
| 714 |
-
self._edge_to_node_proj = nn.Linear(edge_emb_dim, node_emb_dim)
|
| 715 |
-
else:
|
| 716 |
-
self._edge_to_node_proj = None
|
| 717 |
-
self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
|
| 718 |
-
# global pooling projection
|
| 719 |
-
self.pool_proj = nn.Linear(node_emb_dim, node_emb_dim)
|
| 720 |
-
|
| 721 |
-
# node-level classifier head for reconstructing atomic ids if needed
|
| 722 |
-
self.node_classifier = nn.Linear(node_emb_dim, MASK_ATOM_ID+1)
|
| 723 |
-
|
| 724 |
-
def _compute_node_reps(self, z, chirality, formal_charge, edge_index, edge_attr):
|
| 725 |
-
device = next(self.parameters()).device
|
| 726 |
-
atom_embedding = self.atom_emb(z.to(device))
|
| 727 |
-
if chirality is None or formal_charge is None:
|
| 728 |
-
node_attr = torch.zeros((z.size(0), 2), device=device)
|
| 729 |
-
else:
|
| 730 |
-
node_attr = torch.stack([chirality, formal_charge], dim=1).to(atom_embedding.device)
|
| 731 |
-
node_attr_emb = self.node_attr_proj(node_attr)
|
| 732 |
-
x = atom_embedding + node_attr_emb
|
| 733 |
-
if edge_attr is None or edge_attr.numel() == 0:
|
| 734 |
-
edge_emb = torch.zeros((0, EDGE_EMB_DIM), dtype=torch.float, device=x.device)
|
| 735 |
-
else:
|
| 736 |
-
edge_emb = self.edge_encoder(edge_attr.to(x.device))
|
| 737 |
-
if self._edge_to_node_proj is not None and edge_emb.numel() > 0:
|
| 738 |
-
edge_for_conv = self._edge_to_node_proj(edge_emb)
|
| 739 |
-
else:
|
| 740 |
-
edge_for_conv = edge_emb
|
| 741 |
-
|
| 742 |
-
h = x
|
| 743 |
-
for layer in self.gnn_layers:
|
| 744 |
-
h = layer(h, edge_index.to(h.device), edge_for_conv)
|
| 745 |
-
return h
|
| 746 |
-
|
| 747 |
-
def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None):
|
| 748 |
-
h = self._compute_node_reps(z, chirality, formal_charge, edge_index, edge_attr)
|
| 749 |
-
if batch is None:
|
| 750 |
-
pooled = torch.mean(h, dim=0, keepdim=True)
|
| 751 |
-
else:
|
| 752 |
-
bsize = int(batch.max().item() + 1) if batch.numel() > 0 else 1
|
| 753 |
-
pooled = torch.zeros((bsize, h.size(1)), device=h.device)
|
| 754 |
-
for i in range(bsize):
|
| 755 |
-
mask = batch == i
|
| 756 |
-
if mask.sum() == 0:
|
| 757 |
-
continue
|
| 758 |
-
pooled[i] = h[mask].mean(dim=0)
|
| 759 |
-
return self.pool_proj(pooled)
|
| 760 |
-
|
| 761 |
-
def node_logits(self, z, chirality, formal_charge, edge_index, edge_attr):
|
| 762 |
-
h = self._compute_node_reps(z, chirality, formal_charge, edge_index, edge_attr)
|
| 763 |
-
logits = self.node_classifier(h)
|
| 764 |
-
return logits
|
| 765 |
-
|
| 766 |
-
class NodeSchNetWrapper(nn.Module):
|
| 767 |
-
def __init__(self, hidden_channels=SCHNET_HIDDEN, num_interactions=SCHNET_NUM_INTERACTIONS, num_gaussians=SCHNET_NUM_GAUSSIANS, cutoff=SCHNET_CUTOFF, max_num_neighbors=SCHNET_MAX_NEIGHBORS):
|
| 768 |
-
super().__init__()
|
| 769 |
-
if PyGSchNet is None:
|
| 770 |
-
raise RuntimeError("PyG SchNet is not available. Install torch_geometric with compatible extras.")
|
| 771 |
-
self.schnet = PyGSchNet(hidden_channels=hidden_channels, num_filters=hidden_channels, num_interactions=num_interactions, num_gaussians=num_gaussians, cutoff=cutoff, max_num_neighbors=max_num_neighbors)
|
| 772 |
-
self.pool_proj = nn.Linear(hidden_channels, hidden_channels)
|
| 773 |
-
self.cutoff = cutoff
|
| 774 |
-
self.max_num_neighbors = max_num_neighbors
|
| 775 |
-
self.node_classifier = nn.Linear(hidden_channels, MASK_ATOM_ID+1)
|
| 776 |
-
|
| 777 |
-
def forward(self, z, pos, batch=None):
|
| 778 |
-
device = next(self.parameters()).device
|
| 779 |
-
z = z.to(device)
|
| 780 |
-
pos = pos.to(device)
|
| 781 |
-
if batch is None:
|
| 782 |
-
batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
|
| 783 |
-
try:
|
| 784 |
-
edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors)
|
| 785 |
-
except Exception:
|
| 786 |
-
edge_index = None
|
| 787 |
-
|
| 788 |
-
node_h = None
|
| 789 |
-
try:
|
| 790 |
-
if hasattr(self.schnet, "embedding"):
|
| 791 |
-
node_h = self.schnet.embedding(z)
|
| 792 |
-
else:
|
| 793 |
-
node_h = self.schnet.embedding(z)
|
| 794 |
-
except Exception:
|
| 795 |
-
node_h = None
|
| 796 |
-
|
| 797 |
-
if node_h is not None and edge_index is not None and edge_index.numel() > 0:
|
| 798 |
-
row, col = edge_index
|
| 799 |
-
edge_weight = (pos[row] - pos[col]).norm(dim=-1)
|
| 800 |
-
edge_attr = None
|
| 801 |
-
if hasattr(self.schnet, "distance_expansion"):
|
| 802 |
-
try:
|
| 803 |
-
edge_attr = self.schnet.distance_expansion(edge_weight)
|
| 804 |
-
except Exception:
|
| 805 |
-
edge_attr = None
|
| 806 |
-
if edge_attr is None and hasattr(self.schnet, "gaussian_smearing"):
|
| 807 |
-
try:
|
| 808 |
-
edge_attr = self.schnet.gaussian_smearing(edge_weight)
|
| 809 |
-
except Exception:
|
| 810 |
-
edge_attr = None
|
| 811 |
-
if hasattr(self.schnet, "interactions") and getattr(self.schnet, "interactions") is not None:
|
| 812 |
-
for interaction in self.schnet.interactions:
|
| 813 |
-
try:
|
| 814 |
-
node_h = node_h + interaction(node_h, edge_index, edge_weight, edge_attr)
|
| 815 |
-
except TypeError:
|
| 816 |
-
node_h = node_h + interaction(node_h, edge_index, edge_weight)
|
| 817 |
-
if node_h is None:
|
| 818 |
-
try:
|
| 819 |
-
out = self.schnet(z=z, pos=pos, batch=batch)
|
| 820 |
-
if isinstance(out, torch.Tensor) and out.dim() == 2 and out.size(0) == z.size(0):
|
| 821 |
-
node_h = out
|
| 822 |
-
elif hasattr(out, "last_hidden_state"):
|
| 823 |
-
node_h = out.last_hidden_state
|
| 824 |
-
elif isinstance(out, (tuple, list)) and len(out) > 0 and isinstance(out[0], torch.Tensor):
|
| 825 |
-
cand = out[0]
|
| 826 |
-
if cand.dim() == 2 and cand.size(0) == z.size(0):
|
| 827 |
-
node_h = cand
|
| 828 |
-
except Exception as e:
|
| 829 |
-
raise RuntimeError("Failed to obtain node-level embeddings from PyG SchNet.") from e
|
| 830 |
-
|
| 831 |
-
bsize = int(batch.max().item()) + 1 if z.numel() > 0 else 1
|
| 832 |
-
pooled = torch.zeros((bsize, node_h.size(1)), device=node_h.device)
|
| 833 |
-
for i in range(bsize):
|
| 834 |
-
mask = batch == i
|
| 835 |
-
if mask.sum() == 0:
|
| 836 |
-
continue
|
| 837 |
-
pooled[i] = node_h[mask].mean(dim=0)
|
| 838 |
-
return self.pool_proj(pooled)
|
| 839 |
-
|
| 840 |
-
def node_logits(self, z, pos, batch=None):
|
| 841 |
-
device = next(self.parameters()).device
|
| 842 |
-
z = z.to(device)
|
| 843 |
-
pos = pos.to(device)
|
| 844 |
-
if batch is None:
|
| 845 |
-
batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
|
| 846 |
-
try:
|
| 847 |
-
edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors)
|
| 848 |
-
except Exception:
|
| 849 |
-
edge_index = None
|
| 850 |
-
|
| 851 |
-
node_h = None
|
| 852 |
-
try:
|
| 853 |
-
if hasattr(self.schnet, "embedding"):
|
| 854 |
-
node_h = self.schnet.embedding(z)
|
| 855 |
-
except Exception:
|
| 856 |
-
node_h = None
|
| 857 |
-
|
| 858 |
-
if node_h is not None and edge_index is not None and edge_index.numel() > 0:
|
| 859 |
-
row, col = edge_index
|
| 860 |
-
edge_weight = (pos[row] - pos[col]).norm(dim=-1)
|
| 861 |
-
edge_attr = None
|
| 862 |
-
if hasattr(self.schnet, "distance_expansion"):
|
| 863 |
-
try:
|
| 864 |
-
edge_attr = self.schnet.distance_expansion(edge_weight)
|
| 865 |
-
except Exception:
|
| 866 |
-
edge_attr = None
|
| 867 |
-
if edge_attr is None and hasattr(self.schnet, "gaussian_smearing"):
|
| 868 |
-
try:
|
| 869 |
-
edge_attr = self.schnet.gaussian_smearing(edge_weight)
|
| 870 |
-
except Exception:
|
| 871 |
-
edge_attr = None
|
| 872 |
-
if hasattr(self.schnet, "interactions") and getattr(self.schnet, "interactions") is not None:
|
| 873 |
-
for interaction in self.schnet.interactions:
|
| 874 |
-
try:
|
| 875 |
-
node_h = node_h + interaction(node_h, edge_index, edge_weight, edge_attr)
|
| 876 |
-
except TypeError:
|
| 877 |
-
node_h = node_h + interaction(node_h, edge_index, edge_weight)
|
| 878 |
-
|
| 879 |
-
if node_h is None:
|
| 880 |
-
out = self.schnet(z=z, pos=pos, batch=batch)
|
| 881 |
-
if isinstance(out, torch.Tensor):
|
| 882 |
-
node_h = out
|
| 883 |
-
elif hasattr(out, "last_hidden_state"):
|
| 884 |
-
node_h = out.last_hidden_state
|
| 885 |
-
elif isinstance(out, (tuple, list)) and len(out) > 0 and isinstance(out[0], torch.Tensor):
|
| 886 |
-
node_h = out[0]
|
| 887 |
-
else:
|
| 888 |
-
raise RuntimeError("Unable to obtain node embeddings for SchNet node_logits")
|
| 889 |
-
|
| 890 |
-
logits = self.node_classifier(node_h)
|
| 891 |
-
return logits
|
| 892 |
-
|
| 893 |
-
class FingerprintEncoder(nn.Module):
|
| 894 |
-
def __init__(self, vocab_size=VOCAB_SIZE_FP, hidden_dim=256, seq_len=FP_LENGTH, num_layers=4, nhead=8, dim_feedforward=1024, dropout=0.1):
|
| 895 |
-
super().__init__()
|
| 896 |
-
self.token_emb = nn.Embedding(vocab_size, hidden_dim)
|
| 897 |
-
self.pos_emb = nn.Embedding(seq_len, hidden_dim)
|
| 898 |
-
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
|
| 899 |
-
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
| 900 |
-
self.pool_proj = nn.Linear(hidden_dim, hidden_dim)
|
| 901 |
-
self.seq_len = seq_len
|
| 902 |
-
self.token_proj = nn.Linear(hidden_dim, vocab_size)
|
| 903 |
-
|
| 904 |
-
def forward(self, input_ids, attention_mask=None):
|
| 905 |
-
device = next(self.parameters()).device
|
| 906 |
-
input_ids = input_ids.to(device)
|
| 907 |
-
B, L = input_ids.shape
|
| 908 |
-
x = self.token_emb(input_ids)
|
| 909 |
-
pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
|
| 910 |
-
x = x + self.pos_emb(pos_ids)
|
| 911 |
-
if attention_mask is not None:
|
| 912 |
-
key_padding_mask = ~attention_mask.to(input_ids.device)
|
| 913 |
-
else:
|
| 914 |
-
key_padding_mask = None
|
| 915 |
-
out = self.transformer(x, src_key_padding_mask=key_padding_mask)
|
| 916 |
-
if attention_mask is None:
|
| 917 |
-
pooled = out.mean(dim=1)
|
| 918 |
-
else:
|
| 919 |
-
am = attention_mask.to(out.device).float().unsqueeze(-1)
|
| 920 |
-
pooled = (out * am).sum(dim=1) / (am.sum(dim=1).clamp(min=1.0))
|
| 921 |
-
return self.pool_proj(pooled)
|
| 922 |
-
|
| 923 |
-
def token_logits(self, input_ids, attention_mask=None):
|
| 924 |
-
device = next(self.parameters()).device
|
| 925 |
-
input_ids = input_ids.to(device)
|
| 926 |
-
B, L = input_ids.shape
|
| 927 |
-
x = self.token_emb(input_ids)
|
| 928 |
-
pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
|
| 929 |
-
x = x + self.pos_emb(pos_ids)
|
| 930 |
-
if attention_mask is not None:
|
| 931 |
-
key_padding_mask = ~attention_mask.to(input_ids.device)
|
| 932 |
-
else:
|
| 933 |
-
key_padding_mask = None
|
| 934 |
-
out = self.transformer(x, src_key_padding_mask=key_padding_mask)
|
| 935 |
-
logits = self.token_proj(out)
|
| 936 |
-
return logits
|
| 937 |
-
|
| 938 |
-
class PSMILESDebertaEncoder(nn.Module):
|
| 939 |
-
def __init__(self, model_dir_or_name: Optional[str] = None):
|
| 940 |
-
super().__init__()
|
| 941 |
-
try:
|
| 942 |
-
if model_dir_or_name is not None and os.path.isdir(model_dir_or_name):
|
| 943 |
-
self.model = DebertaV2ForMaskedLM.from_pretrained(model_dir_or_name)
|
| 944 |
-
else:
|
| 945 |
-
self.model = DebertaV2ForMaskedLM.from_pretrained("microsoft/deberta-v2-xlarge")
|
| 946 |
-
except Exception as e:
|
| 947 |
-
print("Warning: couldn't load DebertaV2 pretrained weights; initializing randomly.", e)
|
| 948 |
-
from transformers import DebertaV2Config
|
| 949 |
-
cfg = DebertaV2Config(vocab_size=getattr(tokenizer, "vocab_size", 300), hidden_size=DEBERTA_HIDDEN, num_attention_heads=12, num_hidden_layers=12, intermediate_size=512)
|
| 950 |
-
self.model = DebertaV2ForMaskedLM(cfg)
|
| 951 |
-
self.pool_proj = nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size)
|
| 952 |
-
|
| 953 |
-
def forward(self, input_ids, attention_mask=None):
|
| 954 |
-
device = next(self.parameters()).device
|
| 955 |
-
input_ids = input_ids.to(device)
|
| 956 |
-
if attention_mask is not None:
|
| 957 |
-
attention_mask = attention_mask.to(device)
|
| 958 |
-
outputs = self.model.base_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
|
| 959 |
-
last_hidden = outputs.last_hidden_state
|
| 960 |
-
if attention_mask is None:
|
| 961 |
-
pooled = last_hidden.mean(dim=1)
|
| 962 |
-
else:
|
| 963 |
-
am = attention_mask.unsqueeze(-1).to(last_hidden.device).float()
|
| 964 |
-
pooled = (last_hidden * am).sum(dim=1) / (am.sum(dim=1).clamp(min=1.0))
|
| 965 |
-
return self.pool_proj(pooled)
|
| 966 |
-
|
| 967 |
-
def token_logits(self, input_ids, attention_mask=None, labels=None):
|
| 968 |
-
device = next(self.parameters()).device
|
| 969 |
-
input_ids = input_ids.to(device)
|
| 970 |
-
if attention_mask is not None:
|
| 971 |
-
attention_mask = attention_mask.to(device)
|
| 972 |
-
if labels is not None:
|
| 973 |
-
labels = labels.to(device)
|
| 974 |
-
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
|
| 975 |
-
return outputs.loss
|
| 976 |
-
else:
|
| 977 |
-
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
|
| 978 |
-
return outputs.logits
|
| 979 |
-
|
| 980 |
-
# ---------------------------
|
| 981 |
-
# Multimodal wrapper & loss (kept same)
|
| 982 |
# ---------------------------
|
| 983 |
class MultimodalContrastiveModel(nn.Module):
|
| 984 |
def __init__(self,
|
|
|
|
| 22 |
import torch.nn.functional as F
|
| 23 |
from torch.utils.data import Dataset, DataLoader
|
| 24 |
|
| 25 |
+
# Shared model utilities
|
| 26 |
+
from GINE import GineEncoder, match_edge_attr_to_index, safe_get
|
| 27 |
+
from SchNet import NodeSchNetWrapper
|
| 28 |
+
from Transformer import PooledFingerprintEncoder as FingerprintEncoder
|
| 29 |
+
from DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# HF Trainer & Transformers
|
| 32 |
+
from transformers import TrainingArguments, Trainer
|
| 33 |
from transformers import DataCollatorForLanguageModeling
|
| 34 |
from transformers.trainer_callback import TrainerCallback
|
| 35 |
|
|
|
|
| 37 |
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
|
| 38 |
|
| 39 |
# ---------------------------
|
| 40 |
+
# Config / Hyperparams (paths are placeholders; update for your environment)
|
| 41 |
# ---------------------------
|
| 42 |
P_MASK = 0.15
|
| 43 |
MAX_ATOMIC_Z = 85
|
|
|
|
| 48 |
EDGE_EMB_DIM = 300
|
| 49 |
NUM_GNN_LAYERS = 5
|
| 50 |
|
| 51 |
+
# SchNet params
|
| 52 |
SCHNET_NUM_GAUSSIANS = 50
|
| 53 |
SCHNET_NUM_INTERACTIONS = 6
|
| 54 |
SCHNET_CUTOFF = 10.0
|
| 55 |
SCHNET_MAX_NEIGHBORS = 64
|
| 56 |
SCHNET_HIDDEN = 600
|
| 57 |
|
| 58 |
+
# Transformer params
|
| 59 |
FP_LENGTH = 2048
|
| 60 |
+
MASK_TOKEN_ID_FP = 2
|
| 61 |
VOCAB_SIZE_FP = 3
|
| 62 |
|
| 63 |
+
# DeBERTav2 params
|
| 64 |
DEBERTA_HIDDEN = 600
|
| 65 |
PSMILES_MAX_LEN = 128
|
| 66 |
|
|
|
|
| 68 |
TEMPERATURE = 0.07
|
| 69 |
|
| 70 |
# Reconstruction loss weight (balance between contrastive and reconstruction objectives)
|
| 71 |
+
REC_LOSS_WEIGHT = 1.0 # could be tuned (e.g., 0.5, 1.0)
|
| 72 |
|
| 73 |
+
# Training args
|
| 74 |
+
OUTPUT_DIR = "/path/to/multimodal_output"
|
| 75 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 76 |
+
BEST_GINE_DIR = "/path/to/gin_output/best"
|
| 77 |
+
BEST_SCHNET_DIR = "/path/to/schnet_output/best"
|
| 78 |
+
BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
|
| 79 |
+
BEST_PSMILES_DIR = "/path/to/polybert_output/best"
|
| 80 |
|
| 81 |
training_args = TrainingArguments(
|
| 82 |
output_dir=OUTPUT_DIR,
|
|
|
|
| 122 |
# Utility / small helpers
|
| 123 |
# ---------------------------
|
| 124 |
|
| 125 |
+
# Optimized BFS to compute distances to visible anchors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_idx: np.ndarray, visible_idx: np.ndarray, k_anchors: int):
|
| 127 |
INF = num_nodes + 1
|
| 128 |
selected_dists = np.zeros((num_nodes, k_anchors), dtype=np.float32)
|
|
|
|
| 163 |
return selected_dists, selected_mask
|
| 164 |
|
| 165 |
# ---------------------------
|
| 166 |
+
# Data loading / preprocessing
|
| 167 |
# ---------------------------
|
| 168 |
+
CSV_PATH = "/path/to/polymer_structures_unified_processed.csv"
|
| 169 |
TARGET_ROWS = 2000000
|
| 170 |
CHUNKSIZE = 50000
|
| 171 |
|
| 172 |
+
PREPROC_DIR = "/path/to/preprocessed_samples"
|
| 173 |
os.makedirs(PREPROC_DIR, exist_ok=True)
|
| 174 |
|
| 175 |
# The per-sample file format: torch.save(sample_dict, sample_path)
|
|
|
|
| 190 |
rows_read = 0
|
| 191 |
sample_idx = 0
|
| 192 |
|
|
|
|
| 193 |
for chunk in pd.read_csv(CSV_PATH, engine="python", chunksize=CHUNKSIZE):
|
| 194 |
# Pre-extract columns presence
|
| 195 |
has_graph = "graph" in chunk.columns
|
|
|
|
| 395 |
sample_files = prepare_or_load_data_streaming()
|
| 396 |
|
| 397 |
# ---------------------------
|
| 398 |
+
# Prepare tokenizer for psmiles
|
| 399 |
# ---------------------------
|
| 400 |
+
SPM_MODEL = "/path/to/spm.model"
|
| 401 |
+
tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
# ---------------------------
|
| 404 |
# Lazy dataset: loads per-sample file on demand and tokenizes psmiles on-the-fly
|
|
|
|
| 583 |
val_loader = DataLoader(val_subset, batch_size=training_args.per_device_eval_batch_size, shuffle=False, collate_fn=multimodal_collate, num_workers=0, drop_last=False)
|
| 584 |
|
| 585 |
# ---------------------------
|
| 586 |
+
# Multimodal wrapper & loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
# ---------------------------
|
| 588 |
class MultimodalContrastiveModel(nn.Module):
|
| 589 |
def __init__(self,
|