Spaces:
Running
Running
manpreet88
commited on
Commit
·
21839e7
1
Parent(s):
6a3741a
Update CL.py
Browse files- PolyFusion/CL.py +826 -511
PolyFusion/CL.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import csv
|
|
@@ -22,7 +27,7 @@ import torch.nn as nn
|
|
| 22 |
import torch.nn.functional as F
|
| 23 |
from torch.utils.data import Dataset, DataLoader
|
| 24 |
|
| 25 |
-
# Shared model utilities
|
| 26 |
from GINE import GineEncoder, match_edge_attr_to_index, safe_get
|
| 27 |
from SchNet import NodeSchNetWrapper
|
| 28 |
from Transformer import PooledFingerprintEncoder as FingerprintEncoder
|
|
@@ -30,15 +35,15 @@ from DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
|
|
| 30 |
|
| 31 |
# HF Trainer & Transformers
|
| 32 |
from transformers import TrainingArguments, Trainer
|
| 33 |
-
from transformers import DataCollatorForLanguageModeling
|
| 34 |
from transformers.trainer_callback import TrainerCallback
|
| 35 |
|
| 36 |
from sklearn.model_selection import train_test_split
|
| 37 |
-
from sklearn.metrics import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
# ---------------------------
|
| 40 |
-
# Config / Hyperparams (paths are placeholders; update for your environment)
|
| 41 |
-
# ---------------------------
|
| 42 |
P_MASK = 0.15
|
| 43 |
MAX_ATOMIC_Z = 85
|
| 44 |
MASK_ATOM_ID = MAX_ATOMIC_Z + 1
|
|
@@ -48,104 +53,112 @@ NODE_EMB_DIM = 300
|
|
| 48 |
EDGE_EMB_DIM = 300
|
| 49 |
NUM_GNN_LAYERS = 5
|
| 50 |
|
| 51 |
-
# SchNet params
|
| 52 |
SCHNET_NUM_GAUSSIANS = 50
|
| 53 |
SCHNET_NUM_INTERACTIONS = 6
|
| 54 |
SCHNET_CUTOFF = 10.0
|
| 55 |
SCHNET_MAX_NEIGHBORS = 64
|
| 56 |
SCHNET_HIDDEN = 600
|
| 57 |
|
| 58 |
-
# Transformer params
|
| 59 |
FP_LENGTH = 2048
|
| 60 |
-
MASK_TOKEN_ID_FP = 2
|
| 61 |
VOCAB_SIZE_FP = 3
|
| 62 |
|
| 63 |
-
#
|
| 64 |
DEBERTA_HIDDEN = 600
|
| 65 |
PSMILES_MAX_LEN = 128
|
| 66 |
|
| 67 |
# Contrastive params
|
| 68 |
TEMPERATURE = 0.07
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
|
| 72 |
|
| 73 |
-
#
|
| 74 |
OUTPUT_DIR = "/path/to/multimodal_output"
|
| 75 |
-
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 76 |
BEST_GINE_DIR = "/path/to/gin_output/best"
|
| 77 |
BEST_SCHNET_DIR = "/path/to/schnet_output/best"
|
| 78 |
BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
|
| 79 |
BEST_PSMILES_DIR = "/path/to/polybert_output/best"
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
if USE_CUDA:
|
| 119 |
-
torch.cuda.manual_seed_all(SEED)
|
| 120 |
-
|
| 121 |
-
# ---------------------------
|
| 122 |
-
# Utility / small helpers
|
| 123 |
-
# ---------------------------
|
| 124 |
-
|
| 125 |
-
# Optimized BFS to compute distances to visible anchors
|
| 126 |
-
def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_idx: np.ndarray, visible_idx: np.ndarray, k_anchors: int):
|
| 127 |
-
INF = num_nodes + 1
|
| 128 |
selected_dists = np.zeros((num_nodes, k_anchors), dtype=np.float32)
|
| 129 |
selected_mask = np.zeros((num_nodes, k_anchors), dtype=np.bool_)
|
|
|
|
| 130 |
if edge_index is None or edge_index.numel() == 0:
|
| 131 |
return selected_dists, selected_mask
|
|
|
|
| 132 |
src = edge_index[0].tolist()
|
| 133 |
dst = edge_index[1].tolist()
|
|
|
|
| 134 |
adj = [[] for _ in range(num_nodes)]
|
| 135 |
for u, v in zip(src, dst):
|
| 136 |
if 0 <= u < num_nodes and 0 <= v < num_nodes:
|
| 137 |
adj[u].append(v)
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
for a in np.atleast_1d(masked_idx).tolist():
|
| 140 |
if a < 0 or a >= num_nodes:
|
| 141 |
continue
|
|
|
|
| 142 |
q = [a]
|
| 143 |
visited = [-1] * num_nodes
|
| 144 |
visited[a] = 0
|
| 145 |
head = 0
|
| 146 |
found = []
|
|
|
|
| 147 |
while head < len(q) and len(found) < k_anchors:
|
| 148 |
-
u = q[head]
|
|
|
|
| 149 |
for v in adj[u]:
|
| 150 |
if visited[v] == -1:
|
| 151 |
visited[v] = visited[u] + 1
|
|
@@ -154,137 +167,162 @@ def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_id
|
|
| 154 |
found.append((visited[v], v))
|
| 155 |
if len(found) >= k_anchors:
|
| 156 |
break
|
|
|
|
| 157 |
if len(found) > 0:
|
| 158 |
found.sort(key=lambda x: x[0])
|
| 159 |
k = min(k_anchors, len(found))
|
| 160 |
for i in range(k):
|
| 161 |
selected_dists[a, i] = float(found[i][0])
|
| 162 |
selected_mask[a, i] = True
|
|
|
|
| 163 |
return selected_dists, selected_mask
|
| 164 |
|
| 165 |
-
# ---------------------------
|
| 166 |
-
# Data loading / preprocessing
|
| 167 |
-
# ---------------------------
|
| 168 |
-
CSV_PATH = "/path/to/polymer_structures_unified_processed.csv"
|
| 169 |
-
TARGET_ROWS = 2000000
|
| 170 |
-
CHUNKSIZE = 50000
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
if len(existing) > 0:
|
| 186 |
-
print(f"Found {len(existing)} preprocessed sample files in {
|
| 187 |
return [str(p) for p in existing]
|
| 188 |
|
| 189 |
print("No existing per-sample preprocessed folder found. Parsing CSV chunked and writing per-sample files (streaming).")
|
| 190 |
-
|
| 191 |
sample_idx = 0
|
| 192 |
|
| 193 |
-
for chunk in pd.read_csv(
|
| 194 |
-
# Pre-extract columns presence
|
| 195 |
has_graph = "graph" in chunk.columns
|
| 196 |
has_geometry = "geometry" in chunk.columns
|
| 197 |
has_fp = "fingerprints" in chunk.columns
|
| 198 |
has_psmiles = "psmiles" in chunk.columns
|
| 199 |
|
| 200 |
for i_row in range(len(chunk)):
|
| 201 |
-
if
|
| 202 |
break
|
|
|
|
| 203 |
row = chunk.iloc[i_row]
|
| 204 |
|
| 205 |
-
#
|
| 206 |
gine_sample = None
|
| 207 |
schnet_sample = None
|
| 208 |
fp_sample = None
|
| 209 |
psmiles_raw = None
|
| 210 |
|
| 211 |
-
#
|
| 212 |
if has_graph:
|
| 213 |
val = row.get("graph", "")
|
| 214 |
try:
|
| 215 |
-
graph_field =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
except Exception:
|
| 217 |
graph_field = None
|
|
|
|
| 218 |
if graph_field:
|
| 219 |
node_features = safe_get(graph_field, "node_features", None)
|
| 220 |
if node_features:
|
| 221 |
atomic_nums = []
|
| 222 |
chirality_vals = []
|
| 223 |
formal_charges = []
|
|
|
|
| 224 |
for nf in node_features:
|
| 225 |
an = safe_get(nf, "atomic_num", None)
|
| 226 |
if an is None:
|
| 227 |
an = safe_get(nf, "atomic_number", 0)
|
| 228 |
ch = safe_get(nf, "chirality", 0)
|
| 229 |
fc = safe_get(nf, "formal_charge", 0)
|
|
|
|
| 230 |
try:
|
| 231 |
atomic_nums.append(int(an))
|
| 232 |
except Exception:
|
| 233 |
atomic_nums.append(0)
|
|
|
|
| 234 |
chirality_vals.append(float(ch))
|
| 235 |
formal_charges.append(float(fc))
|
| 236 |
-
|
| 237 |
edge_indices_raw = safe_get(graph_field, "edge_indices", None)
|
| 238 |
edge_features_raw = safe_get(graph_field, "edge_features", None)
|
|
|
|
| 239 |
edge_index = None
|
| 240 |
edge_attr = None
|
|
|
|
|
|
|
| 241 |
if edge_indices_raw is None:
|
| 242 |
adj_mat = safe_get(graph_field, "adjacency_matrix", None)
|
| 243 |
if adj_mat:
|
| 244 |
-
srcs = []
|
| 245 |
-
dsts = []
|
| 246 |
for i_r, row_adj in enumerate(adj_mat):
|
| 247 |
for j, val2 in enumerate(row_adj):
|
| 248 |
if val2:
|
| 249 |
-
srcs.append(i_r)
|
|
|
|
| 250 |
if len(srcs) > 0:
|
| 251 |
edge_index = [srcs, dsts]
|
| 252 |
E = len(srcs)
|
| 253 |
edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
|
| 254 |
else:
|
|
|
|
|
|
|
|
|
|
| 255 |
srcs, dsts = [], []
|
| 256 |
-
|
| 257 |
if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list):
|
| 258 |
-
# either list of pairs or two lists
|
| 259 |
first = edge_indices_raw[0]
|
| 260 |
if len(first) == 2 and isinstance(first[0], int):
|
| 261 |
-
#
|
| 262 |
try:
|
| 263 |
srcs = [int(p[0]) for p in edge_indices_raw]
|
| 264 |
dsts = [int(p[1]) for p in edge_indices_raw]
|
| 265 |
except Exception:
|
| 266 |
srcs, dsts = [], []
|
| 267 |
else:
|
| 268 |
-
#
|
| 269 |
try:
|
| 270 |
srcs = [int(x) for x in edge_indices_raw[0]]
|
| 271 |
dsts = [int(x) for x in edge_indices_raw[1]]
|
| 272 |
except Exception:
|
| 273 |
srcs, dsts = [], []
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
| 275 |
srcs = [int(p[0]) for p in edge_indices_raw]
|
| 276 |
dsts = [int(p[1]) for p in edge_indices_raw]
|
|
|
|
| 277 |
if len(srcs) > 0:
|
| 278 |
edge_index = [srcs, dsts]
|
|
|
|
| 279 |
if edge_features_raw and isinstance(edge_features_raw, list):
|
| 280 |
-
bond_types = []
|
| 281 |
-
stereos = []
|
| 282 |
-
is_conjs = []
|
| 283 |
for ef in edge_features_raw:
|
| 284 |
bt = safe_get(ef, "bond_type", 0)
|
| 285 |
st = safe_get(ef, "stereo", 0)
|
| 286 |
ic = safe_get(ef, "is_conjugated", False)
|
| 287 |
-
bond_types.append(float(bt))
|
|
|
|
|
|
|
| 288 |
edge_attr = list(zip(bond_types, stereos, is_conjs))
|
| 289 |
else:
|
| 290 |
E = len(srcs)
|
|
@@ -299,11 +337,15 @@ def prepare_or_load_data_streaming():
|
|
| 299 |
"edge_attr": edge_attr,
|
| 300 |
}
|
| 301 |
|
| 302 |
-
#
|
| 303 |
if has_geometry and schnet_sample is None:
|
| 304 |
val = row.get("geometry", "")
|
| 305 |
try:
|
| 306 |
-
geom =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
conf = geom.get("best_conformer") if isinstance(geom, dict) else None
|
| 308 |
if conf:
|
| 309 |
atomic = conf.get("atomic_numbers", [])
|
|
@@ -313,12 +355,13 @@ def prepare_or_load_data_streaming():
|
|
| 313 |
except Exception:
|
| 314 |
schnet_sample = None
|
| 315 |
|
| 316 |
-
#
|
| 317 |
if has_fp:
|
| 318 |
fpval = row.get("fingerprints", "")
|
| 319 |
if fpval is None or (isinstance(fpval, str) and fpval.strip() == ""):
|
| 320 |
fp_sample = [0] * FP_LENGTH
|
| 321 |
else:
|
|
|
|
| 322 |
try:
|
| 323 |
fp_json = json.loads(fpval) if isinstance(fpval, str) else fpval
|
| 324 |
except Exception:
|
|
@@ -330,8 +373,13 @@ def prepare_or_load_data_streaming():
|
|
| 330 |
if len(bits) < FP_LENGTH:
|
| 331 |
bits += [0] * (FP_LENGTH - len(bits))
|
| 332 |
fp_sample = bits
|
|
|
|
| 333 |
if fp_sample is None:
|
| 334 |
-
bits =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
if bits is None:
|
| 336 |
fp_sample = [0] * FP_LENGTH
|
| 337 |
else:
|
|
@@ -344,107 +392,122 @@ def prepare_or_load_data_streaming():
|
|
| 344 |
normalized.append(1 if int(b) != 0 else 0)
|
| 345 |
else:
|
| 346 |
normalized.append(0)
|
|
|
|
| 347 |
if len(normalized) >= FP_LENGTH:
|
| 348 |
break
|
|
|
|
| 349 |
if len(normalized) < FP_LENGTH:
|
| 350 |
normalized.extend([0] * (FP_LENGTH - len(normalized)))
|
| 351 |
fp_sample = normalized[:FP_LENGTH]
|
| 352 |
|
| 353 |
-
#
|
| 354 |
if has_psmiles:
|
| 355 |
s = row.get("psmiles", "")
|
| 356 |
-
if s is None
|
| 357 |
-
psmiles_raw = ""
|
| 358 |
-
else:
|
| 359 |
-
psmiles_raw = str(s)
|
| 360 |
|
| 361 |
-
#
|
| 362 |
-
|
| 363 |
-
|
|
|
|
| 364 |
if modalities_present >= 2:
|
| 365 |
sample = {
|
| 366 |
"gine": gine_sample,
|
| 367 |
"schnet": schnet_sample,
|
| 368 |
"fp": fp_sample,
|
| 369 |
-
"psmiles_raw": psmiles_raw
|
| 370 |
}
|
| 371 |
-
|
|
|
|
| 372 |
try:
|
| 373 |
torch.save(sample, sample_path)
|
| 374 |
except Exception as save_e:
|
| 375 |
print("Warning: failed to torch.save sample:", save_e)
|
| 376 |
-
# fallback
|
| 377 |
try:
|
| 378 |
with open(sample_path + ".json", "w") as fjson:
|
| 379 |
json.dump(sample, fjson)
|
| 380 |
-
# indicate via filename with .json
|
| 381 |
-
sample_path = sample_path + ".json"
|
| 382 |
except Exception:
|
| 383 |
pass
|
| 384 |
|
| 385 |
sample_idx += 1
|
| 386 |
-
|
| 387 |
|
| 388 |
-
|
| 389 |
-
if rows_read >= TARGET_ROWS:
|
| 390 |
break
|
| 391 |
|
| 392 |
-
print(f"Wrote {sample_idx} sample files to {
|
| 393 |
-
return [str(p) for p in sorted(Path(
|
| 394 |
|
| 395 |
-
sample_files = prepare_or_load_data_streaming()
|
| 396 |
|
| 397 |
-
#
|
| 398 |
-
#
|
| 399 |
-
#
|
| 400 |
-
SPM_MODEL = "/path/to/spm.model"
|
| 401 |
-
tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
|
| 402 |
|
| 403 |
-
# ---------------------------
|
| 404 |
-
# Lazy dataset: loads per-sample file on demand and tokenizes psmiles on-the-fly
|
| 405 |
-
# ---------------------------
|
| 406 |
class LazyMultimodalDataset(Dataset):
|
| 407 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
self.files = sample_file_list
|
| 409 |
self.tokenizer = tokenizer
|
| 410 |
self.fp_length = fp_length
|
| 411 |
self.psmiles_max_len = psmiles_max_len
|
| 412 |
|
| 413 |
-
def __len__(self):
|
| 414 |
return len(self.files)
|
| 415 |
|
| 416 |
-
def __getitem__(self, idx):
|
| 417 |
sample_path = self.files[idx]
|
| 418 |
-
|
|
|
|
| 419 |
if sample_path.endswith(".pt"):
|
| 420 |
sample = torch.load(sample_path, map_location="cpu")
|
| 421 |
else:
|
| 422 |
-
# fallback json load
|
| 423 |
with open(sample_path, "r") as f:
|
| 424 |
sample = json.load(f)
|
| 425 |
|
| 426 |
-
# GINE
|
| 427 |
gine_raw = sample.get("gine", None)
|
| 428 |
-
gine_item = None
|
| 429 |
if gine_raw:
|
| 430 |
node_atomic = torch.tensor(gine_raw.get("node_atomic", []), dtype=torch.long)
|
| 431 |
node_chirality = torch.tensor(gine_raw.get("node_chirality", []), dtype=torch.float)
|
| 432 |
node_charge = torch.tensor(gine_raw.get("node_charge", []), dtype=torch.float)
|
|
|
|
| 433 |
if gine_raw.get("edge_index", None) is not None:
|
| 434 |
-
|
| 435 |
-
edge_index = torch.tensor(ei, dtype=torch.long)
|
| 436 |
else:
|
| 437 |
edge_index = torch.tensor([[], []], dtype=torch.long)
|
|
|
|
| 438 |
ea_raw = gine_raw.get("edge_attr", None)
|
| 439 |
if ea_raw:
|
| 440 |
edge_attr = torch.tensor(ea_raw, dtype=torch.float)
|
| 441 |
else:
|
| 442 |
edge_attr = torch.zeros((edge_index.size(1), 3), dtype=torch.float)
|
| 443 |
-
gine_item = {"z": node_atomic, "chirality": node_chirality, "formal_charge": node_charge, "edge_index": edge_index, "edge_attr": edge_attr}
|
| 444 |
-
else:
|
| 445 |
-
gine_item = {"z": torch.tensor([], dtype=torch.long), "chirality": torch.tensor([], dtype=torch.float), "formal_charge": torch.tensor([], dtype=torch.float), "edge_index": torch.tensor([[], []], dtype=torch.long), "edge_attr": torch.zeros((0, 3), dtype=torch.float)}
|
| 446 |
|
| 447 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
schnet_raw = sample.get("schnet", None)
|
| 449 |
if schnet_raw:
|
| 450 |
s_z = torch.tensor(schnet_raw.get("atomic", []), dtype=torch.long)
|
|
@@ -453,12 +516,11 @@ class LazyMultimodalDataset(Dataset):
|
|
| 453 |
else:
|
| 454 |
schnet_item = {"z": torch.tensor([], dtype=torch.long), "pos": torch.tensor([], dtype=torch.float)}
|
| 455 |
|
| 456 |
-
# Fingerprint
|
| 457 |
fp_raw = sample.get("fp", None)
|
| 458 |
if fp_raw is None:
|
| 459 |
fp_vec = torch.zeros((self.fp_length,), dtype=torch.long)
|
| 460 |
else:
|
| 461 |
-
# if fp_raw is already tensor-like, handle it
|
| 462 |
if isinstance(fp_raw, (list, tuple)):
|
| 463 |
arr = list(fp_raw)[:self.fp_length]
|
| 464 |
if len(arr) < self.fp_length:
|
|
@@ -467,85 +529,87 @@ class LazyMultimodalDataset(Dataset):
|
|
| 467 |
elif isinstance(fp_raw, torch.Tensor):
|
| 468 |
fp_vec = fp_raw.clone().to(torch.long)
|
| 469 |
else:
|
| 470 |
-
# fallback
|
| 471 |
fp_vec = torch.zeros((self.fp_length,), dtype=torch.long)
|
| 472 |
|
| 473 |
-
# PSMILES
|
| 474 |
-
psm_raw = sample.get("psmiles_raw", "")
|
| 475 |
-
if psm_raw is None:
|
| 476 |
-
psm_raw = ""
|
| 477 |
enc = self.tokenizer(psm_raw, truncation=True, padding="max_length", max_length=self.psmiles_max_len)
|
| 478 |
p_input_ids = torch.tensor(enc["input_ids"], dtype=torch.long)
|
| 479 |
p_attn = torch.tensor(enc["attention_mask"], dtype=torch.bool)
|
| 480 |
|
| 481 |
return {
|
| 482 |
-
"gine": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
"schnet": {"z": schnet_item["z"], "pos": schnet_item["pos"]},
|
| 484 |
"fp": {"input_ids": fp_vec},
|
| 485 |
-
"psmiles": {"input_ids": p_input_ids, "attention_mask": p_attn}
|
| 486 |
}
|
| 487 |
|
| 488 |
-
# instantiate dataset lazily
|
| 489 |
-
dataset = LazyMultimodalDataset(sample_files, tokenizer, fp_length=FP_LENGTH, psmiles_max_len=PSMILES_MAX_LEN)
|
| 490 |
|
| 491 |
-
|
| 492 |
-
train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=42)
|
| 493 |
-
train_subset = torch.utils.data.Subset(dataset, train_idx)
|
| 494 |
-
val_subset = torch.utils.data.Subset(dataset, val_idx)
|
| 495 |
-
|
| 496 |
-
# For manual evaluation (used by evaluate_multimodal), create a DataLoader with num_workers=0
|
| 497 |
-
def multimodal_collate(batch_list):
|
| 498 |
"""
|
| 499 |
-
|
| 500 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
"""
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
all_ch = []
|
| 506 |
-
all_fc = []
|
| 507 |
-
all_edge_index = []
|
| 508 |
-
all_edge_attr = []
|
| 509 |
batch_mapping = []
|
| 510 |
node_offset = 0
|
|
|
|
| 511 |
for i, item in enumerate(batch_list):
|
| 512 |
g = item["gine"]
|
| 513 |
z = g["z"]
|
| 514 |
n = z.size(0)
|
|
|
|
| 515 |
all_z.append(z)
|
| 516 |
all_ch.append(g["chirality"])
|
| 517 |
all_fc.append(g["formal_charge"])
|
| 518 |
batch_mapping.append(torch.full((n,), i, dtype=torch.long))
|
|
|
|
| 519 |
if g["edge_index"] is not None and g["edge_index"].numel() > 0:
|
| 520 |
ei_offset = g["edge_index"] + node_offset
|
| 521 |
all_edge_index.append(ei_offset)
|
|
|
|
|
|
|
| 522 |
ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
|
| 523 |
all_edge_attr.append(ea)
|
|
|
|
| 524 |
node_offset += n
|
|
|
|
| 525 |
if len(all_z) == 0:
|
| 526 |
-
# create zero-length placeholders for empty batch
|
| 527 |
z_batch = torch.tensor([], dtype=torch.long)
|
| 528 |
ch_batch = torch.tensor([], dtype=torch.float)
|
| 529 |
fc_batch = torch.tensor([], dtype=torch.float)
|
| 530 |
batch_batch = torch.tensor([], dtype=torch.long)
|
| 531 |
-
edge_index_batched = torch.empty((2,0), dtype=torch.long)
|
| 532 |
-
edge_attr_batched = torch.zeros((0,3), dtype=torch.float)
|
| 533 |
else:
|
| 534 |
z_batch = torch.cat(all_z, dim=0)
|
| 535 |
ch_batch = torch.cat(all_ch, dim=0)
|
| 536 |
fc_batch = torch.cat(all_fc, dim=0)
|
| 537 |
batch_batch = torch.cat(batch_mapping, dim=0)
|
|
|
|
| 538 |
if len(all_edge_index) > 0:
|
| 539 |
edge_index_batched = torch.cat(all_edge_index, dim=1)
|
| 540 |
edge_attr_batched = torch.cat(all_edge_attr, dim=0)
|
| 541 |
else:
|
| 542 |
-
edge_index_batched = torch.empty((2,0), dtype=torch.long)
|
| 543 |
-
edge_attr_batched = torch.zeros((0,3), dtype=torch.float)
|
| 544 |
|
| 545 |
-
# SchNet batching
|
| 546 |
-
all_sz = []
|
| 547 |
-
all_pos = []
|
| 548 |
-
schnet_batch = []
|
| 549 |
for i, item in enumerate(batch_list):
|
| 550 |
s = item["schnet"]
|
| 551 |
s_z = s["z"]
|
|
@@ -555,6 +619,7 @@ def multimodal_collate(batch_list):
|
|
| 555 |
all_sz.append(s_z)
|
| 556 |
all_pos.append(s_pos)
|
| 557 |
schnet_batch.append(torch.full((s_z.size(0),), i, dtype=torch.long))
|
|
|
|
| 558 |
if len(all_sz) == 0:
|
| 559 |
s_z_batch = torch.tensor([], dtype=torch.long)
|
| 560 |
s_pos_batch = torch.tensor([], dtype=torch.float)
|
|
@@ -564,375 +629,467 @@ def multimodal_collate(batch_list):
|
|
| 564 |
s_pos_batch = torch.cat(all_pos, dim=0)
|
| 565 |
s_batch_batch = torch.cat(schnet_batch, dim=0)
|
| 566 |
|
| 567 |
-
# FP batching
|
| 568 |
-
fp_ids = torch.stack(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
fp_attn = torch.ones_like(fp_ids, dtype=torch.bool)
|
| 570 |
|
| 571 |
-
# PSMILES
|
| 572 |
p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch_list], dim=0)
|
| 573 |
p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch_list], dim=0)
|
| 574 |
|
| 575 |
return {
|
| 576 |
-
"gine": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
"schnet": {"z": s_z_batch, "pos": s_pos_batch, "batch": s_batch_batch},
|
| 578 |
"fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
|
| 579 |
-
"psmiles": {"input_ids": p_ids, "attention_mask": p_attn}
|
| 580 |
}
|
| 581 |
|
| 582 |
-
train_loader = DataLoader(train_subset, batch_size=training_args.per_device_train_batch_size, shuffle=True, collate_fn=multimodal_collate, num_workers=0, drop_last=False)
|
| 583 |
-
val_loader = DataLoader(val_subset, batch_size=training_args.per_device_eval_batch_size, shuffle=False, collate_fn=multimodal_collate, num_workers=0, drop_last=False)
|
| 584 |
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
class MultimodalContrastiveModel(nn.Module):
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
super().__init__()
|
| 596 |
self.gine = gine_encoder
|
| 597 |
self.schnet = schnet_encoder
|
| 598 |
self.fp = fp_encoder
|
| 599 |
self.psmiles = psmiles_encoder
|
|
|
|
| 600 |
self.proj_gine = nn.Linear(getattr(self.gine, "pool_proj").out_features if self.gine is not None else emb_dim, emb_dim) if self.gine is not None else None
|
| 601 |
self.proj_schnet = nn.Linear(getattr(self.schnet, "pool_proj").out_features if self.schnet is not None else emb_dim, emb_dim) if self.schnet is not None else None
|
| 602 |
self.proj_fp = nn.Linear(getattr(self.fp, "pool_proj").out_features if self.fp is not None else emb_dim, emb_dim) if self.fp is not None else None
|
| 603 |
self.proj_psmiles = nn.Linear(getattr(self.psmiles, "pool_proj").out_features if self.psmiles is not None else emb_dim, emb_dim) if self.psmiles is not None else None
|
|
|
|
| 604 |
self.temperature = TEMPERATURE
|
| 605 |
-
self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100, reduction=
|
| 606 |
|
| 607 |
def encode(self, batch_mods: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
| 608 |
-
|
| 609 |
embs = {}
|
| 610 |
-
|
| 611 |
-
if
|
| 612 |
-
g = batch_mods[
|
| 613 |
-
emb_g = self.gine(g[
|
| 614 |
-
embs[
|
| 615 |
-
|
| 616 |
-
if
|
| 617 |
-
s = batch_mods[
|
| 618 |
-
emb_s = self.schnet(s[
|
| 619 |
-
embs[
|
| 620 |
-
|
| 621 |
-
if
|
| 622 |
-
f = batch_mods[
|
| 623 |
-
emb_f = self.fp(f[
|
| 624 |
-
embs[
|
| 625 |
-
|
| 626 |
-
if
|
| 627 |
-
p = batch_mods[
|
| 628 |
-
emb_p = self.psmiles(p[
|
| 629 |
-
embs[
|
| 630 |
-
|
| 631 |
return embs
|
| 632 |
|
| 633 |
def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str):
|
|
|
|
|
|
|
|
|
|
| 634 |
device = next(self.parameters()).device
|
| 635 |
embs = self.encode(batch_mods)
|
| 636 |
info = {}
|
|
|
|
| 637 |
if mask_target not in embs:
|
| 638 |
return torch.tensor(0.0, device=device), {"batch_size": 0}
|
|
|
|
| 639 |
target = embs[mask_target]
|
| 640 |
other_keys = [k for k in embs.keys() if k != mask_target]
|
| 641 |
if len(other_keys) == 0:
|
| 642 |
return torch.tensor(0.0, device=device), {"batch_size": target.size(0)}
|
|
|
|
| 643 |
anchor = torch.stack([embs[k] for k in other_keys], dim=0).mean(dim=0)
|
| 644 |
logits = torch.matmul(anchor, target.T) / self.temperature
|
| 645 |
B = logits.size(0)
|
| 646 |
labels = torch.arange(B, device=logits.device)
|
| 647 |
info_nce_loss = F.cross_entropy(logits, labels)
|
| 648 |
-
info[
|
| 649 |
|
|
|
|
| 650 |
rec_losses = []
|
| 651 |
rec_details = {}
|
| 652 |
|
|
|
|
| 653 |
try:
|
| 654 |
-
if
|
| 655 |
-
gm = batch_mods[
|
| 656 |
-
labels_nodes = gm.get(
|
| 657 |
if labels_nodes is not None:
|
| 658 |
-
node_logits = self.gine.node_logits(gm[
|
| 659 |
if labels_nodes.dim() == 1 and node_logits.size(0) == labels_nodes.size(0):
|
| 660 |
loss_gine = self.ce_loss(node_logits, labels_nodes.to(node_logits.device))
|
| 661 |
rec_losses.append(loss_gine)
|
| 662 |
-
rec_details[
|
| 663 |
except Exception as e:
|
| 664 |
print("Warning: GINE reconstruction loss computation failed:", e)
|
| 665 |
|
|
|
|
| 666 |
try:
|
| 667 |
-
if
|
| 668 |
-
sm = batch_mods[
|
| 669 |
-
labels_nodes = sm.get(
|
| 670 |
if labels_nodes is not None:
|
| 671 |
-
node_logits = self.schnet.node_logits(sm[
|
| 672 |
if labels_nodes.dim() == 1 and node_logits.size(0) == labels_nodes.size(0):
|
| 673 |
loss_schnet = self.ce_loss(node_logits, labels_nodes.to(node_logits.device))
|
| 674 |
rec_losses.append(loss_schnet)
|
| 675 |
-
rec_details[
|
| 676 |
except Exception as e:
|
| 677 |
print("Warning: SchNet reconstruction loss computation failed:", e)
|
| 678 |
|
|
|
|
| 679 |
try:
|
| 680 |
-
if
|
| 681 |
-
fm = batch_mods[
|
| 682 |
-
labels_fp = fm.get(
|
| 683 |
if labels_fp is not None:
|
| 684 |
-
token_logits = self.fp.token_logits(fm[
|
| 685 |
Bf, Lf, V = token_logits.shape
|
| 686 |
logits2 = token_logits.view(-1, V)
|
| 687 |
labels2 = labels_fp.view(-1).to(logits2.device)
|
| 688 |
loss_fp = self.ce_loss(logits2, labels2)
|
| 689 |
rec_losses.append(loss_fp)
|
| 690 |
-
rec_details[
|
| 691 |
except Exception as e:
|
| 692 |
print("Warning: FP reconstruction loss computation failed:", e)
|
| 693 |
|
|
|
|
| 694 |
try:
|
| 695 |
-
if
|
| 696 |
-
pm = batch_mods[
|
| 697 |
-
labels_ps = pm.get(
|
| 698 |
-
if labels_ps is not None
|
| 699 |
-
loss_ps = self.psmiles.token_logits(pm[
|
| 700 |
if isinstance(loss_ps, torch.Tensor):
|
| 701 |
rec_losses.append(loss_ps)
|
| 702 |
-
rec_details[
|
| 703 |
except Exception as e:
|
| 704 |
print("Warning: PSMILES MLM loss computation failed:", e)
|
| 705 |
|
| 706 |
if len(rec_losses) > 0:
|
| 707 |
rec_loss_total = sum(rec_losses) / len(rec_losses)
|
| 708 |
-
info[
|
| 709 |
total_loss = info_nce_loss + REC_LOSS_WEIGHT * rec_loss_total
|
| 710 |
-
info[
|
| 711 |
info.update(rec_details)
|
| 712 |
else:
|
| 713 |
total_loss = info_nce_loss
|
| 714 |
-
info[
|
| 715 |
-
info[
|
| 716 |
|
| 717 |
return total_loss, info
|
| 718 |
|
| 719 |
-
# ---------------------------
|
| 720 |
-
# Instantiate encoders (load weights if available) and move to device with .to(device)
|
| 721 |
-
gine_encoder = GineEncoder(node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS, max_atomic_z=MAX_ATOMIC_Z)
|
| 722 |
-
if os.path.exists(os.path.join(BEST_GINE_DIR, "pytorch_model.bin")):
|
| 723 |
-
try:
|
| 724 |
-
gine_encoder.load_state_dict(torch.load(os.path.join(BEST_GINE_DIR, "pytorch_model.bin"), map_location="cpu"), strict=False)
|
| 725 |
-
print("Loaded GINE best weights from", BEST_GINE_DIR)
|
| 726 |
-
except Exception as e:
|
| 727 |
-
print("Could not load GINE best weights:", e)
|
| 728 |
-
gine_encoder.to(device)
|
| 729 |
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
schnet_encoder.load_state_dict(torch.load(os.path.join(BEST_SCHNET_DIR, "pytorch_model.bin"), map_location="cpu"), strict=False)
|
| 734 |
-
print("Loaded SchNet best weights from", BEST_SCHNET_DIR)
|
| 735 |
-
except Exception as e:
|
| 736 |
-
print("Could not load SchNet best weights:", e)
|
| 737 |
-
schnet_encoder.to(device)
|
| 738 |
-
|
| 739 |
-
fp_encoder = FingerprintEncoder(vocab_size=VOCAB_SIZE_FP, hidden_dim=256, seq_len=FP_LENGTH, num_layers=4, nhead=8, dim_feedforward=1024, dropout=0.1)
|
| 740 |
-
if os.path.exists(os.path.join(BEST_FP_DIR, "pytorch_model.bin")):
|
| 741 |
-
try:
|
| 742 |
-
fp_encoder.load_state_dict(torch.load(os.path.join(BEST_FP_DIR, "pytorch_model.bin"), map_location="cpu"), strict=False)
|
| 743 |
-
print("Loaded fingerprint encoder best weights from", BEST_FP_DIR)
|
| 744 |
-
except Exception as e:
|
| 745 |
-
print("Could not load fingerprint best weights:", e)
|
| 746 |
-
fp_encoder.to(device)
|
| 747 |
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
print("Failed to load Deberta from saved directory:", e)
|
| 755 |
-
if psmiles_encoder is None:
|
| 756 |
-
try:
|
| 757 |
-
psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=None)
|
| 758 |
-
except Exception as e:
|
| 759 |
-
print("Failed to instantiate Deberta encoder:", e)
|
| 760 |
-
psmiles_encoder.to(device)
|
| 761 |
|
| 762 |
-
|
| 763 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
|
| 765 |
-
# ---------------------------
|
| 766 |
-
# Helper to sample masked variant for modalities: (kept same, device-safe)
|
| 767 |
-
def mask_batch_for_modality(batch: dict, modality: str, p_mask: float = P_MASK):
|
| 768 |
-
b = {}
|
| 769 |
-
# GINE:
|
| 770 |
-
if 'gine' in batch:
|
| 771 |
-
z = batch['gine']['z'].clone()
|
| 772 |
-
chir = batch['gine']['chirality'].clone()
|
| 773 |
-
fc = batch['gine']['formal_charge'].clone()
|
| 774 |
-
edge_index = batch['gine']['edge_index']
|
| 775 |
-
edge_attr = batch['gine']['edge_attr']
|
| 776 |
-
batch_map = batch['gine'].get('batch', None)
|
| 777 |
n_nodes = z.size(0)
|
| 778 |
dev = z.device
|
| 779 |
is_selected = torch.rand(n_nodes, device=dev) < p_mask
|
| 780 |
if is_selected.numel() > 0 and is_selected.all():
|
| 781 |
is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False
|
|
|
|
| 782 |
labels_z = torch.full_like(z, fill_value=-100)
|
| 783 |
if is_selected.any():
|
| 784 |
sel_idx = torch.nonzero(is_selected).squeeze(-1)
|
| 785 |
if sel_idx.dim() == 0:
|
| 786 |
sel_idx = sel_idx.unsqueeze(0)
|
|
|
|
| 787 |
labels_z[is_selected] = z[is_selected]
|
| 788 |
-
rand_atomic = torch.randint(1, MAX_ATOMIC_Z+1, (sel_idx.size(0),), dtype=torch.long, device=dev)
|
|
|
|
| 789 |
probs = torch.rand(sel_idx.size(0), device=dev)
|
| 790 |
mask_choice = probs < 0.8
|
| 791 |
rand_choice = (probs >= 0.8) & (probs < 0.9)
|
|
|
|
| 792 |
if mask_choice.any():
|
| 793 |
z[sel_idx[mask_choice]] = MASK_ATOM_ID
|
| 794 |
if rand_choice.any():
|
| 795 |
z[sel_idx[rand_choice]] = rand_atomic[rand_choice]
|
| 796 |
-
b['gine'] = {"z": z, "chirality": chir, "formal_charge": fc, "edge_index": edge_index, "edge_attr": edge_attr, "batch": batch_map, "labels": labels_z}
|
| 797 |
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 803 |
n_nodes = z.size(0)
|
| 804 |
dev = z.device
|
| 805 |
is_selected = torch.rand(n_nodes, device=dev) < p_mask
|
| 806 |
if is_selected.numel() > 0 and is_selected.all():
|
| 807 |
is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False
|
|
|
|
| 808 |
labels_z = torch.full((n_nodes,), -100, dtype=torch.long, device=dev)
|
| 809 |
if is_selected.any():
|
| 810 |
sel_idx = torch.nonzero(is_selected).squeeze(-1)
|
| 811 |
if sel_idx.dim() == 0:
|
| 812 |
sel_idx = sel_idx.unsqueeze(0)
|
|
|
|
| 813 |
labels_z[is_selected] = z[is_selected]
|
| 814 |
probs_c = torch.rand(sel_idx.size(0), device=dev)
|
| 815 |
noisy_choice = probs_c < 0.8
|
| 816 |
randpos_choice = (probs_c >= 0.8) & (probs_c < 0.9)
|
|
|
|
| 817 |
if noisy_choice.any():
|
| 818 |
idx = sel_idx[noisy_choice]
|
| 819 |
noise = torch.randn((idx.size(0), 3), device=pos.device) * 0.5
|
| 820 |
pos[idx] = pos[idx] + noise
|
|
|
|
| 821 |
if randpos_choice.any():
|
| 822 |
idx = sel_idx[randpos_choice]
|
| 823 |
mins = pos.min(dim=0).values
|
| 824 |
maxs = pos.max(dim=0).values
|
| 825 |
randpos = (torch.rand((idx.size(0), 3), device=pos.device) * (maxs - mins)) + mins
|
| 826 |
pos[idx] = randpos
|
| 827 |
-
b['schnet'] = {"z": z, "pos": pos, "batch": batch_map, "labels": labels_z}
|
| 828 |
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
|
|
|
|
|
|
|
|
|
| 833 |
B, L = input_ids.shape
|
| 834 |
dev = input_ids.device
|
| 835 |
labels_z = torch.full_like(input_ids, -100)
|
|
|
|
| 836 |
for i in range(B):
|
| 837 |
sel = torch.rand(L, device=dev) < p_mask
|
| 838 |
if sel.numel() > 0 and sel.all():
|
| 839 |
sel[torch.randint(0, L, (1,), device=dev)] = False
|
|
|
|
| 840 |
sel_idx = torch.nonzero(sel).squeeze(-1)
|
| 841 |
if sel_idx.numel() > 0:
|
| 842 |
if sel_idx.dim() == 0:
|
| 843 |
sel_idx = sel_idx.unsqueeze(0)
|
|
|
|
| 844 |
labels_z[i, sel_idx] = input_ids[i, sel_idx]
|
|
|
|
| 845 |
probs = torch.rand(sel_idx.size(0), device=dev)
|
| 846 |
mask_choice = probs < 0.8
|
| 847 |
rand_choice = (probs >= 0.8) & (probs < 0.9)
|
|
|
|
| 848 |
if mask_choice.any():
|
| 849 |
input_ids[i, sel_idx[mask_choice]] = MASK_TOKEN_ID_FP
|
| 850 |
if rand_choice.any():
|
| 851 |
rand_bits = torch.randint(0, 2, (rand_choice.sum().item(),), dtype=torch.long, device=dev)
|
| 852 |
input_ids[i, sel_idx[rand_choice]] = rand_bits
|
| 853 |
-
b['fp'] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}
|
| 854 |
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
|
|
|
|
|
|
|
|
|
| 859 |
B, L = input_ids.shape
|
| 860 |
dev = input_ids.device
|
| 861 |
labels_z = torch.full_like(input_ids, -100)
|
|
|
|
|
|
|
| 862 |
if tokenizer is None:
|
| 863 |
-
b[
|
| 864 |
else:
|
| 865 |
mask_token_id = tokenizer.mask_token_id if getattr(tokenizer, "mask_token_id", None) is not None else getattr(tokenizer, "vocab", {}).get("<mask>", 1)
|
|
|
|
| 866 |
for i in range(B):
|
| 867 |
sel = torch.rand(L, device=dev) < p_mask
|
| 868 |
if sel.numel() > 0 and sel.all():
|
| 869 |
sel[torch.randint(0, L, (1,), device=dev)] = False
|
|
|
|
| 870 |
sel_idx = torch.nonzero(sel).squeeze(-1)
|
| 871 |
if sel_idx.numel() > 0:
|
| 872 |
if sel_idx.dim() == 0:
|
| 873 |
sel_idx = sel_idx.unsqueeze(0)
|
|
|
|
| 874 |
labels_z[i, sel_idx] = input_ids[i, sel_idx]
|
|
|
|
| 875 |
probs = torch.rand(sel_idx.size(0), device=dev)
|
| 876 |
mask_choice = probs < 0.8
|
| 877 |
rand_choice = (probs >= 0.8) & (probs < 0.9)
|
|
|
|
| 878 |
if mask_choice.any():
|
| 879 |
input_ids[i, sel_idx[mask_choice]] = mask_token_id
|
| 880 |
if rand_choice.any():
|
| 881 |
rand_ids = torch.randint(0, getattr(tokenizer, "vocab_size", 300), (rand_choice.sum().item(),), dtype=torch.long, device=dev)
|
| 882 |
input_ids[i, sel_idx[rand_choice]] = rand_ids
|
| 883 |
-
|
|
|
|
| 884 |
|
| 885 |
return b
|
| 886 |
|
| 887 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 888 |
mm = {}
|
| 889 |
-
if
|
| 890 |
-
gm = masked_batch[
|
| 891 |
-
mm[
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 901 |
return mm
|
| 902 |
|
| 903 |
-
|
| 904 |
-
#
|
| 905 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 906 |
model.eval()
|
| 907 |
total_loss = 0.0
|
| 908 |
total_examples = 0
|
| 909 |
acc_sum = 0.0
|
| 910 |
-
top5_sum = 0.0
|
| 911 |
-
mrr_sum = 0.0
|
| 912 |
-
mean_pos_logit_sum = 0.0
|
| 913 |
-
mean_neg_logit_sum = 0.0
|
| 914 |
f1_sum = 0.0
|
| 915 |
|
| 916 |
with torch.no_grad():
|
| 917 |
for batch in val_loader:
|
| 918 |
-
masked_batch = mask_batch_for_modality(batch, mask_target, p_mask=P_MASK)
|
| 919 |
-
|
|
|
|
| 920 |
for k in masked_batch:
|
| 921 |
for subk in masked_batch[k]:
|
| 922 |
if isinstance(masked_batch[k][subk], torch.Tensor):
|
| 923 |
masked_batch[k][subk] = masked_batch[k][subk].to(device)
|
|
|
|
| 924 |
mm_in = mm_batch_to_model_input(masked_batch)
|
| 925 |
embs = model.encode(mm_in)
|
|
|
|
| 926 |
if mask_target not in embs:
|
| 927 |
continue
|
|
|
|
| 928 |
target = embs[mask_target]
|
| 929 |
other_keys = [k for k in embs.keys() if k != mask_target]
|
| 930 |
if len(other_keys) == 0:
|
| 931 |
continue
|
|
|
|
| 932 |
anchor = torch.stack([embs[k] for k in other_keys], dim=0).mean(dim=0)
|
| 933 |
logits = torch.matmul(anchor, target.T) / model.temperature
|
|
|
|
| 934 |
B = logits.size(0)
|
| 935 |
labels = torch.arange(B, device=logits.device)
|
|
|
|
| 936 |
loss = F.cross_entropy(logits, labels)
|
| 937 |
total_loss += loss.item() * B
|
| 938 |
total_examples += B
|
|
@@ -941,46 +1098,14 @@ def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader, device, m
|
|
| 941 |
acc = (preds == labels).float().mean().item()
|
| 942 |
acc_sum += acc * B
|
| 943 |
|
| 944 |
-
|
| 945 |
-
topk = min(5, B)
|
| 946 |
-
topk_indices = torch.topk(logits, k=topk, dim=1).indices
|
| 947 |
-
hits_topk = (topk_indices == labels.unsqueeze(1)).any(dim=1).float().mean().item()
|
| 948 |
-
top5_sum += hits_topk * B
|
| 949 |
-
else:
|
| 950 |
-
top5_sum += acc * B
|
| 951 |
-
|
| 952 |
-
sorted_desc = torch.argsort(logits, dim=1, descending=True)
|
| 953 |
-
positions = (sorted_desc == labels.unsqueeze(1)).nonzero(as_tuple=False)
|
| 954 |
-
ranks = torch.zeros(B, device=logits.device).float()
|
| 955 |
-
if positions.numel() > 0:
|
| 956 |
-
for p in positions:
|
| 957 |
-
i, pos = int(p[0].item()), int(p[1].item())
|
| 958 |
-
ranks[i] = pos + 1.0
|
| 959 |
-
ranks_nonzero = ranks.clone()
|
| 960 |
-
ranks_nonzero[ranks_nonzero == 0] = float('inf')
|
| 961 |
-
mrr = (1.0 / ranks_nonzero).mean().item()
|
| 962 |
-
mrr_sum += mrr * B
|
| 963 |
-
|
| 964 |
-
pos_logits = logits[torch.arange(B), labels]
|
| 965 |
-
neg_logits = logits.clone()
|
| 966 |
-
neg_logits[torch.arange(B), labels] = float('-inf')
|
| 967 |
-
neg_mask = neg_logits != float('-inf')
|
| 968 |
-
if neg_mask.any():
|
| 969 |
-
row_counts = neg_mask.sum(dim=1).clamp(min=1).float()
|
| 970 |
-
sum_neg_per_row = neg_logits.masked_fill(~neg_mask, 0.0).sum(dim=1)
|
| 971 |
-
mean_neg = (sum_neg_per_row / row_counts).mean().item()
|
| 972 |
-
else:
|
| 973 |
-
mean_neg = 0.0
|
| 974 |
-
mean_pos_logit_sum += pos_logits.mean().item() * B
|
| 975 |
-
mean_neg_logit_sum += mean_neg * B
|
| 976 |
-
|
| 977 |
try:
|
| 978 |
labels_np = labels.cpu().numpy()
|
| 979 |
preds_np = preds.cpu().numpy()
|
| 980 |
if len(np.unique(labels_np)) < 2:
|
| 981 |
batch_f1 = float(acc)
|
| 982 |
else:
|
| 983 |
-
batch_f1 = f1_score(labels_np, preds_np, average=
|
| 984 |
except Exception:
|
| 985 |
batch_f1 = float(acc)
|
| 986 |
f1_sum += batch_f1 * B
|
|
@@ -988,18 +1113,29 @@ def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader, device, m
|
|
| 988 |
if total_examples == 0:
|
| 989 |
return {"eval_loss": float("nan"), "eval_accuracy": 0.0, "eval_f1_weighted": 0.0}
|
| 990 |
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
|
|
|
|
|
|
|
|
|
| 994 |
|
| 995 |
-
|
|
|
|
|
|
|
| 996 |
|
| 997 |
-
# ---------------------------
|
| 998 |
-
# HF wrapper / Trainer integration (kept same as your part 2, uses lazy loaders)
|
| 999 |
class HFMultimodalModule(nn.Module):
|
| 1000 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1001 |
super().__init__()
|
| 1002 |
self.mm = mm_model
|
|
|
|
| 1003 |
|
| 1004 |
def forward(self, **kwargs):
|
| 1005 |
if "batch" in kwargs:
|
|
@@ -1012,16 +1148,23 @@ class HFMultimodalModule(nn.Module):
|
|
| 1012 |
batch = {k: found[k] for k in found}
|
| 1013 |
mask_target = kwargs.get("mask_target", "fp")
|
| 1014 |
else:
|
| 1015 |
-
raise ValueError(
|
| 1016 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1017 |
device = next(self.parameters()).device
|
| 1018 |
for k in masked_batch:
|
| 1019 |
for subk in list(masked_batch[k].keys()):
|
| 1020 |
val = masked_batch[k][subk]
|
| 1021 |
if isinstance(val, torch.Tensor):
|
| 1022 |
masked_batch[k][subk] = val.to(device)
|
|
|
|
| 1023 |
mm_in = mm_batch_to_model_input(masked_batch)
|
| 1024 |
loss, info = self.mm(mm_in, mask_target)
|
|
|
|
| 1025 |
logits = None
|
| 1026 |
labels = None
|
| 1027 |
try:
|
|
@@ -1039,6 +1182,7 @@ class HFMultimodalModule(nn.Module):
|
|
| 1039 |
print("Warning: failed to compute logits/labels inside HFMultimodalModule.forward:", e)
|
| 1040 |
logits = None
|
| 1041 |
labels = None
|
|
|
|
| 1042 |
eval_loss = loss.detach() if isinstance(loss, torch.Tensor) else torch.tensor(float(loss), device=device)
|
| 1043 |
out = {"loss": loss, "eval_loss": eval_loss}
|
| 1044 |
if logits is not None:
|
|
@@ -1048,11 +1192,15 @@ class HFMultimodalModule(nn.Module):
|
|
| 1048 |
out["mm_info"] = info
|
| 1049 |
return out
|
| 1050 |
|
| 1051 |
-
hf_model = HFMultimodalModule(multimodal_model)
|
| 1052 |
-
hf_model.to(device)
|
| 1053 |
|
| 1054 |
class ContrastiveDataCollator:
|
| 1055 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1056 |
self.mask_prob = mask_prob
|
| 1057 |
self.modalities = modalities if modalities is not None else ["gine", "schnet", "fp", "psmiles"]
|
| 1058 |
|
|
@@ -1061,22 +1209,30 @@ class ContrastiveDataCollator:
|
|
| 1061 |
collated = features
|
| 1062 |
mask_target = random.choice([m for m in self.modalities if m in collated])
|
| 1063 |
return {"batch": collated, "mask_target": mask_target}
|
|
|
|
| 1064 |
if isinstance(features, (list, tuple)) and len(features) > 0:
|
| 1065 |
first = features[0]
|
| 1066 |
-
if isinstance(first, dict) and
|
| 1067 |
collated = multimodal_collate(list(features))
|
| 1068 |
mask_target = random.choice([m for m in self.modalities if m in collated])
|
| 1069 |
return {"batch": collated, "mask_target": mask_target}
|
| 1070 |
-
|
| 1071 |
-
|
|
|
|
| 1072 |
mask_target = first.get("mask_target", random.choice([m for m in self.modalities if m in collated]))
|
| 1073 |
return {"batch": collated, "mask_target": mask_target}
|
|
|
|
| 1074 |
print("ContrastiveDataCollator received unexpected 'features' shape/type.")
|
| 1075 |
raise ValueError("ContrastiveDataCollator could not collate input. Expected list[dict] with 'gine' key or already-collated dict.")
|
| 1076 |
|
| 1077 |
-
data_collator = ContrastiveDataCollator(mask_prob=P_MASK)
|
| 1078 |
|
| 1079 |
class VerboseTrainingCallback(TrainerCallback):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1080 |
def __init__(self, patience: int = 10):
|
| 1081 |
self.start_time = time.time()
|
| 1082 |
self.epoch_start_time = time.time()
|
|
@@ -1104,42 +1260,45 @@ class VerboseTrainingCallback(TrainerCallback):
|
|
| 1104 |
|
| 1105 |
def on_train_begin(self, args, state, control, **kwargs):
|
| 1106 |
self.start_time = time.time()
|
| 1107 |
-
print(""
|
| 1108 |
-
print("
|
| 1109 |
-
print("="*80)
|
| 1110 |
-
|
|
|
|
| 1111 |
if model is not None:
|
| 1112 |
total_params = sum(p.numel() for p in model.parameters())
|
| 1113 |
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 1114 |
non_trainable_params = total_params - trainable_params
|
| 1115 |
-
print(
|
| 1116 |
print(f" Total Parameters: {total_params:,}")
|
| 1117 |
print(f" Trainable Parameters: {trainable_params:,}")
|
| 1118 |
print(f" Non-trainable Parameters: {non_trainable_params:,}")
|
| 1119 |
print(f" Training Progress: 0/{args.num_train_epochs} epochs")
|
| 1120 |
-
|
|
|
|
| 1121 |
|
| 1122 |
def on_epoch_begin(self, args, state, control, **kwargs):
|
| 1123 |
self.epoch_start_time = time.time()
|
| 1124 |
current_epoch = state.epoch if state is not None else 0.0
|
| 1125 |
-
print(f"
|
| 1126 |
|
| 1127 |
def on_epoch_end(self, args, state, control, **kwargs):
|
| 1128 |
train_loss = None
|
| 1129 |
for log in reversed(state.log_history):
|
| 1130 |
-
if isinstance(log, dict) and
|
| 1131 |
-
train_loss = log[
|
| 1132 |
break
|
| 1133 |
self._last_train_loss = train_loss
|
| 1134 |
|
| 1135 |
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 1136 |
-
if logs is not None and
|
| 1137 |
current_step = state.global_step
|
| 1138 |
current_epoch = state.epoch
|
| 1139 |
try:
|
| 1140 |
steps_per_epoch = max(1, len(train_loader) // args.gradient_accumulation_steps)
|
| 1141 |
except Exception:
|
| 1142 |
steps_per_epoch = 1
|
|
|
|
| 1143 |
if current_step % max(1, steps_per_epoch // 10) == 0:
|
| 1144 |
progress = current_epoch + (current_step % steps_per_epoch) / steps_per_epoch
|
| 1145 |
print(f" Step {current_step:4d} | Epoch {progress:.1f} | Train Loss: {logs['loss']:.6f}")
|
|
@@ -1147,54 +1306,63 @@ class VerboseTrainingCallback(TrainerCallback):
|
|
| 1147 |
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 1148 |
current_epoch = state.epoch if state is not None else 0.0
|
| 1149 |
epoch_time = time.time() - self.epoch_start_time
|
| 1150 |
-
|
|
|
|
| 1151 |
hf_eval_loss = None
|
| 1152 |
hf_train_loss = self._last_train_loss
|
|
|
|
| 1153 |
if hf_metrics is not None:
|
| 1154 |
-
hf_eval_loss = hf_metrics.get(
|
| 1155 |
if hf_train_loss is None:
|
| 1156 |
-
hf_train_loss = hf_metrics.get(
|
|
|
|
| 1157 |
cl_metrics = {}
|
| 1158 |
try:
|
| 1159 |
-
model = kwargs.get(
|
| 1160 |
if model is not None:
|
| 1161 |
cl_model = model.mm if hasattr(model, "mm") else model
|
| 1162 |
-
cl_metrics = evaluate_multimodal(cl_model, val_loader, device, mask_target="fp")
|
| 1163 |
else:
|
| 1164 |
-
cl_metrics = evaluate_multimodal(multimodal_model, val_loader, device, mask_target="fp")
|
| 1165 |
except Exception as e:
|
| 1166 |
print("Warning: evaluate_multimodal inside callback failed:", e)
|
|
|
|
| 1167 |
if hf_eval_loss is None:
|
| 1168 |
-
hf_eval_loss = cl_metrics.get(
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
|
|
|
|
|
|
|
| 1172 |
if hf_train_loss is not None:
|
| 1173 |
try:
|
| 1174 |
print(f" Train Loss (HF reported): {hf_train_loss:.6f}")
|
| 1175 |
except Exception:
|
| 1176 |
print(f" Train Loss (HF reported): {hf_train_loss}")
|
| 1177 |
else:
|
| 1178 |
-
print(
|
|
|
|
| 1179 |
if hf_eval_loss is not None:
|
| 1180 |
try:
|
| 1181 |
print(f" Eval Loss (HF reported): {hf_eval_loss:.6f}")
|
| 1182 |
except Exception:
|
| 1183 |
print(f" Eval Loss (HF reported): {hf_eval_loss}")
|
| 1184 |
else:
|
| 1185 |
-
print(
|
|
|
|
| 1186 |
if isinstance(val_acc, float):
|
| 1187 |
print(f" Eval Acc (CL evaluator): {val_acc:.6f}")
|
| 1188 |
else:
|
| 1189 |
print(f" Eval Acc (CL evaluator): {val_acc}")
|
|
|
|
| 1190 |
if isinstance(val_f1, float):
|
| 1191 |
print(f" Eval F1 Weighted (CL evaluator): {val_f1:.6f}")
|
| 1192 |
else:
|
| 1193 |
print(f" Eval F1 Weighted (CL evaluator): {val_f1}")
|
| 1194 |
-
|
| 1195 |
-
|
|
|
|
| 1196 |
if current_val < self.best_val_loss - 1e-6:
|
| 1197 |
-
improved = True
|
| 1198 |
self.best_val_loss = current_val
|
| 1199 |
self.best_epoch = current_epoch
|
| 1200 |
self.epochs_no_improve = 0
|
|
@@ -1204,9 +1372,11 @@ class VerboseTrainingCallback(TrainerCallback):
|
|
| 1204 |
print("Warning: saving best model failed:", e)
|
| 1205 |
else:
|
| 1206 |
self.epochs_no_improve += 1
|
|
|
|
| 1207 |
if self.epochs_no_improve >= self.patience:
|
| 1208 |
print(f"Early stopping: no improvement in val_loss for {self.patience} epochs.")
|
| 1209 |
control.should_training_stop = True
|
|
|
|
| 1210 |
print(f" Epoch Training Time: {epoch_time:.2f}s")
|
| 1211 |
print(f" Best Val Loss so far: {self.best_val_loss}")
|
| 1212 |
print(f" Epochs since improvement: {self.epochs_no_improve}/{self.patience}")
|
|
@@ -1214,20 +1384,26 @@ class VerboseTrainingCallback(TrainerCallback):
|
|
| 1214 |
|
| 1215 |
def on_train_end(self, args, state, control, **kwargs):
|
| 1216 |
total_time = time.time() - self.start_time
|
| 1217 |
-
print(""
|
| 1218 |
-
print("
|
| 1219 |
-
print("="*80)
|
| 1220 |
print(f" Total Training Time: {total_time:.2f}s")
|
| 1221 |
if state is not None:
|
| 1222 |
try:
|
| 1223 |
print(f" Total Epochs Completed: {state.epoch + 1:.1f}")
|
| 1224 |
except Exception:
|
| 1225 |
pass
|
| 1226 |
-
print("="*80)
|
|
|
|
| 1227 |
|
| 1228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1229 |
|
| 1230 |
-
class CLTrainer(HfTrainer):
|
| 1231 |
def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
|
| 1232 |
try:
|
| 1233 |
metrics = super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) or {}
|
|
@@ -1237,7 +1413,7 @@ class CLTrainer(HfTrainer):
|
|
| 1237 |
traceback.print_exc()
|
| 1238 |
try:
|
| 1239 |
cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
|
| 1240 |
-
cl_metrics = evaluate_multimodal(cl_model, val_loader, device, mask_target="fp")
|
| 1241 |
metrics = {k: float(v) if isinstance(v, (float, int, np.floating, np.integer)) else v for k, v in cl_metrics.items()}
|
| 1242 |
metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
|
| 1243 |
except Exception as e2:
|
|
@@ -1245,32 +1421,39 @@ class CLTrainer(HfTrainer):
|
|
| 1245 |
traceback.print_exc()
|
| 1246 |
metrics = {"eval_loss": float("nan"), "epoch": float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else 0.0}
|
| 1247 |
return metrics
|
|
|
|
| 1248 |
try:
|
| 1249 |
cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
|
| 1250 |
-
cl_metrics = evaluate_multimodal(cl_model, val_loader, device, mask_target="fp")
|
| 1251 |
except Exception as e:
|
| 1252 |
print("Warning: evaluate_multimodal failed inside CLTrainer.evaluate():", e)
|
| 1253 |
cl_metrics = {}
|
|
|
|
| 1254 |
for k, v in cl_metrics.items():
|
| 1255 |
try:
|
| 1256 |
metrics[k] = float(v)
|
| 1257 |
except Exception:
|
| 1258 |
metrics[k] = v
|
| 1259 |
-
|
|
|
|
| 1260 |
try:
|
| 1261 |
-
metrics[
|
| 1262 |
except Exception:
|
| 1263 |
-
metrics[
|
|
|
|
| 1264 |
if "epoch" not in metrics:
|
| 1265 |
metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
|
|
|
|
| 1266 |
return metrics
|
| 1267 |
|
| 1268 |
def _save(self, output_dir: str):
|
| 1269 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 1270 |
try:
|
| 1271 |
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
| 1272 |
except Exception:
|
| 1273 |
pass
|
|
|
|
| 1274 |
try:
|
| 1275 |
model_to_save = self.model.mm if hasattr(self.model, "mm") else self.model
|
| 1276 |
torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
|
|
@@ -1282,10 +1465,12 @@ class CLTrainer(HfTrainer):
|
|
| 1282 |
raise e
|
| 1283 |
except Exception as e2:
|
| 1284 |
print("Warning: failed to save model state_dict:", e2)
|
|
|
|
| 1285 |
try:
|
| 1286 |
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
| 1287 |
except Exception:
|
| 1288 |
pass
|
|
|
|
| 1289 |
try:
|
| 1290 |
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
| 1291 |
except Exception:
|
|
@@ -1295,14 +1480,17 @@ class CLTrainer(HfTrainer):
|
|
| 1295 |
best_ckpt = self.state.best_model_checkpoint
|
| 1296 |
if not best_ckpt:
|
| 1297 |
return
|
|
|
|
| 1298 |
candidate = os.path.join(best_ckpt, "pytorch_model.bin")
|
| 1299 |
if not os.path.exists(candidate):
|
| 1300 |
candidate = os.path.join(best_ckpt, "model.bin")
|
| 1301 |
if not os.path.exists(candidate):
|
| 1302 |
candidate = None
|
|
|
|
| 1303 |
if candidate is None:
|
| 1304 |
print(f"CLTrainer._load_best_model(): no compatible pytorch_model.bin found in {best_ckpt}; skipping load.")
|
| 1305 |
return
|
|
|
|
| 1306 |
try:
|
| 1307 |
state_dict = torch.load(candidate, map_location=self.args.device)
|
| 1308 |
model_to_load = self.model.mm if hasattr(self.model, "mm") else self.model
|
|
@@ -1312,97 +1500,224 @@ class CLTrainer(HfTrainer):
|
|
| 1312 |
print("CLTrainer._load_best_model: failed to load state_dict using torch.load:", e)
|
| 1313 |
return
|
| 1314 |
|
| 1315 |
-
callback = VerboseTrainingCallback(patience=10)
|
| 1316 |
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
train_dataset=train_subset,
|
| 1321 |
-
eval_dataset=val_subset,
|
| 1322 |
-
data_collator=data_collator,
|
| 1323 |
-
callbacks=[callback],
|
| 1324 |
-
)
|
| 1325 |
|
| 1326 |
-
|
| 1327 |
-
|
| 1328 |
-
|
| 1329 |
-
|
| 1330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1331 |
|
| 1332 |
-
|
| 1333 |
-
|
| 1334 |
|
| 1335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1336 |
|
| 1337 |
-
|
| 1338 |
-
|
| 1339 |
-
|
|
|
|
| 1340 |
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
print(f" Non-trainable Parameters: {non_trainable_params:,}")
|
| 1345 |
|
| 1346 |
-
|
| 1347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1348 |
|
| 1349 |
-
#
|
| 1350 |
-
|
| 1351 |
-
if USE_CUDA:
|
| 1352 |
try:
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1356 |
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
|
| 1361 |
-
training_end_time = time.time()
|
| 1362 |
|
| 1363 |
-
|
| 1364 |
-
|
| 1365 |
-
os.makedirs(BEST_MULTIMODAL_DIR, exist_ok=True)
|
| 1366 |
|
| 1367 |
-
|
| 1368 |
-
|
| 1369 |
-
|
| 1370 |
-
|
| 1371 |
-
|
| 1372 |
-
|
| 1373 |
-
print(f"
|
| 1374 |
-
|
| 1375 |
-
print("
|
| 1376 |
-
|
| 1377 |
-
|
| 1378 |
-
|
| 1379 |
-
|
| 1380 |
-
|
| 1381 |
-
trainer._load_best_model()
|
| 1382 |
-
final_metrics = trainer.evaluate(eval_dataset=val_subset)
|
| 1383 |
-
else:
|
| 1384 |
-
final_metrics = evaluate_multimodal(multimodal_model, val_loader, device, mask_target="fp")
|
| 1385 |
-
except Exception as e:
|
| 1386 |
-
print("Warning: final evaluation via trainer.evaluate failed, falling back to direct evaluate_multimodal:", e)
|
| 1387 |
-
final_metrics = evaluate_multimodal(multimodal_model, val_loader, device, mask_target="fp")
|
| 1388 |
-
|
| 1389 |
-
print("\n" + "="*80)
|
| 1390 |
-
print("🏁 FINAL TRAINING RESULTS")
|
| 1391 |
-
print("="*80)
|
| 1392 |
-
training_time = training_end_time - training_start_time
|
| 1393 |
-
print(f"Total Training Time: {training_time:.2f}s")
|
| 1394 |
-
best_ckpt = trainer.state.best_model_checkpoint if hasattr(trainer.state, 'best_model_checkpoint') else None
|
| 1395 |
-
if best_ckpt:
|
| 1396 |
-
print(f"Best Checkpoint: {best_ckpt}")
|
| 1397 |
-
else:
|
| 1398 |
-
print("Best Checkpoint: (none saved)")
|
| 1399 |
-
|
| 1400 |
-
hf_eval_loss = final_metrics.get('eval_loss', float('nan'))
|
| 1401 |
-
hf_eval_acc = final_metrics.get('eval_accuracy', 0.0)
|
| 1402 |
-
hf_eval_f1 = final_metrics.get('eval_f1_weighted', 0.0)
|
| 1403 |
-
print(f"Val Loss (HF reported / trainer.evaluate): {hf_eval_loss:.4f}")
|
| 1404 |
-
print(f"Val Acc (CL evaluator): {hf_eval_acc:.4f}")
|
| 1405 |
-
print(f"Val F1 Weighted (CL evaluator): {hf_eval_f1:.4f}")
|
| 1406 |
-
print(f"Total Trainable Params: {trainable_params:,}")
|
| 1407 |
-
print(f"Total Non-trainable Params: {non_trainable_params:,}")
|
| 1408 |
-
print("="*80)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PolyFusion - CL.py
|
| 3 |
+
Multimodal contrastive pretraining script (DeBERTaV2 + GINE + SchNet + Transformer).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import os
|
| 7 |
import sys
|
| 8 |
import csv
|
|
|
|
| 27 |
import torch.nn.functional as F
|
| 28 |
from torch.utils.data import Dataset, DataLoader
|
| 29 |
|
| 30 |
+
# Shared model utilities
|
| 31 |
from GINE import GineEncoder, match_edge_attr_to_index, safe_get
|
| 32 |
from SchNet import NodeSchNetWrapper
|
| 33 |
from Transformer import PooledFingerprintEncoder as FingerprintEncoder
|
|
|
|
| 35 |
|
| 36 |
# HF Trainer & Transformers
|
| 37 |
from transformers import TrainingArguments, Trainer
|
|
|
|
| 38 |
from transformers.trainer_callback import TrainerCallback
|
| 39 |
|
| 40 |
from sklearn.model_selection import train_test_split
|
| 41 |
+
from sklearn.metrics import f1_score
|
| 42 |
+
|
| 43 |
+
# =============================================================================
|
| 44 |
+
# Configuration (paths are placeholders; update for your environment)
|
| 45 |
+
# =============================================================================
|
| 46 |
|
|
|
|
|
|
|
|
|
|
| 47 |
P_MASK = 0.15
|
| 48 |
MAX_ATOMIC_Z = 85
|
| 49 |
MASK_ATOM_ID = MAX_ATOMIC_Z + 1
|
|
|
|
| 53 |
EDGE_EMB_DIM = 300
|
| 54 |
NUM_GNN_LAYERS = 5
|
| 55 |
|
| 56 |
+
# SchNet params
|
| 57 |
SCHNET_NUM_GAUSSIANS = 50
|
| 58 |
SCHNET_NUM_INTERACTIONS = 6
|
| 59 |
SCHNET_CUTOFF = 10.0
|
| 60 |
SCHNET_MAX_NEIGHBORS = 64
|
| 61 |
SCHNET_HIDDEN = 600
|
| 62 |
|
| 63 |
+
# Fingerprint Transformer params
|
| 64 |
FP_LENGTH = 2048
|
| 65 |
+
MASK_TOKEN_ID_FP = 2
|
| 66 |
VOCAB_SIZE_FP = 3
|
| 67 |
|
| 68 |
+
# DeBERTaV2 params
|
| 69 |
DEBERTA_HIDDEN = 600
|
| 70 |
PSMILES_MAX_LEN = 128
|
| 71 |
|
| 72 |
# Contrastive params
|
| 73 |
TEMPERATURE = 0.07
|
| 74 |
+
REC_LOSS_WEIGHT = 1.0 # Reconstruction loss weight
|
| 75 |
+
|
| 76 |
+
# Data / preprocessing
|
| 77 |
+
CSV_PATH = "/path/to/polymer_structures_unified_processed.csv"
|
| 78 |
+
TARGET_ROWS = 2000000
|
| 79 |
+
CHUNKSIZE = 50000
|
| 80 |
+
PREPROC_DIR = "/path/to/preprocessed_samples"
|
| 81 |
|
| 82 |
+
# Tokenizer assets
|
| 83 |
+
SPM_MODEL = "/path/to/spm.model"
|
| 84 |
|
| 85 |
+
# Outputs / checkpoints
|
| 86 |
OUTPUT_DIR = "/path/to/multimodal_output"
|
|
|
|
| 87 |
BEST_GINE_DIR = "/path/to/gin_output/best"
|
| 88 |
BEST_SCHNET_DIR = "/path/to/schnet_output/best"
|
| 89 |
BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
|
| 90 |
BEST_PSMILES_DIR = "/path/to/polybert_output/best"
|
| 91 |
|
| 92 |
+
|
| 93 |
+
# =============================================================================
|
| 94 |
+
# Reproducibility + device
|
| 95 |
+
# =============================================================================
|
| 96 |
+
|
| 97 |
+
def get_device() -> torch.device:
|
| 98 |
+
"""Select CUDA if available (respects CUDA_VISIBLE_DEVICES), else CPU."""
|
| 99 |
+
return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def set_seed(seed: int = 42) -> None:
|
| 103 |
+
"""Set Python/Numpy/Torch seeds for deterministic-ish behavior."""
|
| 104 |
+
random.seed(seed)
|
| 105 |
+
np.random.seed(seed)
|
| 106 |
+
torch.manual_seed(seed)
|
| 107 |
+
if torch.cuda.is_available():
|
| 108 |
+
torch.cuda.manual_seed_all(seed)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# =============================================================================
|
| 112 |
+
# Optional helper
|
| 113 |
+
# =============================================================================
|
| 114 |
+
|
| 115 |
+
def bfs_distances_to_visible(
|
| 116 |
+
edge_index: torch.Tensor,
|
| 117 |
+
num_nodes: int,
|
| 118 |
+
masked_idx: np.ndarray,
|
| 119 |
+
visible_idx: np.ndarray,
|
| 120 |
+
k_anchors: int
|
| 121 |
+
):
|
| 122 |
+
"""
|
| 123 |
+
Compute shortest-path distances from masked nodes to nearest visible anchors.
|
| 124 |
+
|
| 125 |
+
This helper was present in your original script. It remains here unchanged in
|
| 126 |
+
behavior (and may be useful for future masking/anchor variants), but it is
|
| 127 |
+
not required for the current CL objective.
|
| 128 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
selected_dists = np.zeros((num_nodes, k_anchors), dtype=np.float32)
|
| 130 |
selected_mask = np.zeros((num_nodes, k_anchors), dtype=np.bool_)
|
| 131 |
+
|
| 132 |
if edge_index is None or edge_index.numel() == 0:
|
| 133 |
return selected_dists, selected_mask
|
| 134 |
+
|
| 135 |
src = edge_index[0].tolist()
|
| 136 |
dst = edge_index[1].tolist()
|
| 137 |
+
|
| 138 |
adj = [[] for _ in range(num_nodes)]
|
| 139 |
for u, v in zip(src, dst):
|
| 140 |
if 0 <= u < num_nodes and 0 <= v < num_nodes:
|
| 141 |
adj[u].append(v)
|
| 142 |
+
|
| 143 |
+
visible_set = (
|
| 144 |
+
set(visible_idx.tolist())
|
| 145 |
+
if isinstance(visible_idx, (np.ndarray, list))
|
| 146 |
+
else set(visible_idx.cpu().tolist())
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
for a in np.atleast_1d(masked_idx).tolist():
|
| 150 |
if a < 0 or a >= num_nodes:
|
| 151 |
continue
|
| 152 |
+
|
| 153 |
q = [a]
|
| 154 |
visited = [-1] * num_nodes
|
| 155 |
visited[a] = 0
|
| 156 |
head = 0
|
| 157 |
found = []
|
| 158 |
+
|
| 159 |
while head < len(q) and len(found) < k_anchors:
|
| 160 |
+
u = q[head]
|
| 161 |
+
head += 1
|
| 162 |
for v in adj[u]:
|
| 163 |
if visited[v] == -1:
|
| 164 |
visited[v] = visited[u] + 1
|
|
|
|
| 167 |
found.append((visited[v], v))
|
| 168 |
if len(found) >= k_anchors:
|
| 169 |
break
|
| 170 |
+
|
| 171 |
if len(found) > 0:
|
| 172 |
found.sort(key=lambda x: x[0])
|
| 173 |
k = min(k_anchors, len(found))
|
| 174 |
for i in range(k):
|
| 175 |
selected_dists[a, i] = float(found[i][0])
|
| 176 |
selected_mask[a, i] = True
|
| 177 |
+
|
| 178 |
return selected_dists, selected_mask
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
+
# =============================================================================
|
| 182 |
+
# Preprocessing (streaming to disk to avoid large memory spikes)
|
| 183 |
+
# =============================================================================
|
| 184 |
+
|
| 185 |
+
def ensure_dir(path: str) -> None:
    """Create *path* (including any missing parents) if it doesn't already exist."""
    Path(path).mkdir(parents=True, exist_ok=True)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def prepare_or_load_data_streaming(
|
| 191 |
+
csv_path: str,
|
| 192 |
+
preproc_dir: str,
|
| 193 |
+
target_rows: int = TARGET_ROWS,
|
| 194 |
+
chunksize: int = CHUNKSIZE
|
| 195 |
+
) -> List[str]:
|
| 196 |
+
"""
|
| 197 |
+
Prepare per-sample serialized files (torch .pt) for lazy loading.
|
| 198 |
+
|
| 199 |
+
- If `preproc_dir` already contains `sample_*.pt`, reuse them.
|
| 200 |
+
- Else stream the CSV in chunks and write `sample_{idx:08d}.pt` files.
|
| 201 |
+
"""
|
| 202 |
+
ensure_dir(preproc_dir)
|
| 203 |
+
|
| 204 |
+
existing = sorted([p for p in Path(preproc_dir).glob("sample_*.pt")])
|
| 205 |
if len(existing) > 0:
|
| 206 |
+
print(f"Found {len(existing)} preprocessed sample files in {preproc_dir}; reusing those (no reparse).")
|
| 207 |
return [str(p) for p in existing]
|
| 208 |
|
| 209 |
print("No existing per-sample preprocessed folder found. Parsing CSV chunked and writing per-sample files (streaming).")
|
| 210 |
+
rows_written = 0
|
| 211 |
sample_idx = 0
|
| 212 |
|
| 213 |
+
for chunk in pd.read_csv(csv_path, engine="python", chunksize=chunksize):
|
|
|
|
| 214 |
has_graph = "graph" in chunk.columns
|
| 215 |
has_geometry = "geometry" in chunk.columns
|
| 216 |
has_fp = "fingerprints" in chunk.columns
|
| 217 |
has_psmiles = "psmiles" in chunk.columns
|
| 218 |
|
| 219 |
for i_row in range(len(chunk)):
|
| 220 |
+
if rows_written >= target_rows:
|
| 221 |
break
|
| 222 |
+
|
| 223 |
row = chunk.iloc[i_row]
|
| 224 |
|
| 225 |
+
# Per-row modality payloads (None if missing)
|
| 226 |
gine_sample = None
|
| 227 |
schnet_sample = None
|
| 228 |
fp_sample = None
|
| 229 |
psmiles_raw = None
|
| 230 |
|
| 231 |
+
# -------- Graph / GINE modality --------
|
| 232 |
if has_graph:
|
| 233 |
val = row.get("graph", "")
|
| 234 |
try:
|
| 235 |
+
graph_field = (
|
| 236 |
+
json.loads(val)
|
| 237 |
+
if isinstance(val, str) and val.strip() != ""
|
| 238 |
+
else (val if not isinstance(val, str) else None)
|
| 239 |
+
)
|
| 240 |
except Exception:
|
| 241 |
graph_field = None
|
| 242 |
+
|
| 243 |
if graph_field:
|
| 244 |
node_features = safe_get(graph_field, "node_features", None)
|
| 245 |
if node_features:
|
| 246 |
atomic_nums = []
|
| 247 |
chirality_vals = []
|
| 248 |
formal_charges = []
|
| 249 |
+
|
| 250 |
for nf in node_features:
|
| 251 |
an = safe_get(nf, "atomic_num", None)
|
| 252 |
if an is None:
|
| 253 |
an = safe_get(nf, "atomic_number", 0)
|
| 254 |
ch = safe_get(nf, "chirality", 0)
|
| 255 |
fc = safe_get(nf, "formal_charge", 0)
|
| 256 |
+
|
| 257 |
try:
|
| 258 |
atomic_nums.append(int(an))
|
| 259 |
except Exception:
|
| 260 |
atomic_nums.append(0)
|
| 261 |
+
|
| 262 |
chirality_vals.append(float(ch))
|
| 263 |
formal_charges.append(float(fc))
|
| 264 |
+
|
| 265 |
edge_indices_raw = safe_get(graph_field, "edge_indices", None)
|
| 266 |
edge_features_raw = safe_get(graph_field, "edge_features", None)
|
| 267 |
+
|
| 268 |
edge_index = None
|
| 269 |
edge_attr = None
|
| 270 |
+
|
| 271 |
+
# Handle missing edge_indices via adjacency_matrix
|
| 272 |
if edge_indices_raw is None:
|
| 273 |
adj_mat = safe_get(graph_field, "adjacency_matrix", None)
|
| 274 |
if adj_mat:
|
| 275 |
+
srcs, dsts = [], []
|
|
|
|
| 276 |
for i_r, row_adj in enumerate(adj_mat):
|
| 277 |
for j, val2 in enumerate(row_adj):
|
| 278 |
if val2:
|
| 279 |
+
srcs.append(i_r)
|
| 280 |
+
dsts.append(j)
|
| 281 |
if len(srcs) > 0:
|
| 282 |
edge_index = [srcs, dsts]
|
| 283 |
E = len(srcs)
|
| 284 |
edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
|
| 285 |
else:
|
| 286 |
+
# edge_indices_raw can be:
|
| 287 |
+
# - list of pairs [[u,v], ...]
|
| 288 |
+
# - two lists [[srcs], [dsts]]
|
| 289 |
srcs, dsts = [], []
|
| 290 |
+
|
| 291 |
if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list):
|
|
|
|
| 292 |
first = edge_indices_raw[0]
|
| 293 |
if len(first) == 2 and isinstance(first[0], int):
|
| 294 |
+
# list of pairs
|
| 295 |
try:
|
| 296 |
srcs = [int(p[0]) for p in edge_indices_raw]
|
| 297 |
dsts = [int(p[1]) for p in edge_indices_raw]
|
| 298 |
except Exception:
|
| 299 |
srcs, dsts = [], []
|
| 300 |
else:
|
| 301 |
+
# two lists
|
| 302 |
try:
|
| 303 |
srcs = [int(x) for x in edge_indices_raw[0]]
|
| 304 |
dsts = [int(x) for x in edge_indices_raw[1]]
|
| 305 |
except Exception:
|
| 306 |
srcs, dsts = [], []
|
| 307 |
+
|
| 308 |
+
if len(srcs) == 0 and isinstance(edge_indices_raw, list) and all(
|
| 309 |
+
isinstance(p, (list, tuple)) and len(p) == 2 for p in edge_indices_raw
|
| 310 |
+
):
|
| 311 |
srcs = [int(p[0]) for p in edge_indices_raw]
|
| 312 |
dsts = [int(p[1]) for p in edge_indices_raw]
|
| 313 |
+
|
| 314 |
if len(srcs) > 0:
|
| 315 |
edge_index = [srcs, dsts]
|
| 316 |
+
|
| 317 |
if edge_features_raw and isinstance(edge_features_raw, list):
|
| 318 |
+
bond_types, stereos, is_conjs = [], [], []
|
|
|
|
|
|
|
| 319 |
for ef in edge_features_raw:
|
| 320 |
bt = safe_get(ef, "bond_type", 0)
|
| 321 |
st = safe_get(ef, "stereo", 0)
|
| 322 |
ic = safe_get(ef, "is_conjugated", False)
|
| 323 |
+
bond_types.append(float(bt))
|
| 324 |
+
stereos.append(float(st))
|
| 325 |
+
is_conjs.append(float(1.0 if ic else 0.0))
|
| 326 |
edge_attr = list(zip(bond_types, stereos, is_conjs))
|
| 327 |
else:
|
| 328 |
E = len(srcs)
|
|
|
|
| 337 |
"edge_attr": edge_attr,
|
| 338 |
}
|
| 339 |
|
| 340 |
+
# -------- Geometry / SchNet modality --------
|
| 341 |
if has_geometry and schnet_sample is None:
|
| 342 |
val = row.get("geometry", "")
|
| 343 |
try:
|
| 344 |
+
geom = (
|
| 345 |
+
json.loads(val)
|
| 346 |
+
if isinstance(val, str) and val.strip() != ""
|
| 347 |
+
else (val if not isinstance(val, str) else None)
|
| 348 |
+
)
|
| 349 |
conf = geom.get("best_conformer") if isinstance(geom, dict) else None
|
| 350 |
if conf:
|
| 351 |
atomic = conf.get("atomic_numbers", [])
|
|
|
|
| 355 |
except Exception:
|
| 356 |
schnet_sample = None
|
| 357 |
|
| 358 |
+
# -------- Fingerprint modality --------
|
| 359 |
if has_fp:
|
| 360 |
fpval = row.get("fingerprints", "")
|
| 361 |
if fpval is None or (isinstance(fpval, str) and fpval.strip() == ""):
|
| 362 |
fp_sample = [0] * FP_LENGTH
|
| 363 |
else:
|
| 364 |
+
fp_json = None
|
| 365 |
try:
|
| 366 |
fp_json = json.loads(fpval) if isinstance(fpval, str) else fpval
|
| 367 |
except Exception:
|
|
|
|
| 373 |
if len(bits) < FP_LENGTH:
|
| 374 |
bits += [0] * (FP_LENGTH - len(bits))
|
| 375 |
fp_sample = bits
|
| 376 |
+
|
| 377 |
if fp_sample is None:
|
| 378 |
+
bits = (
|
| 379 |
+
safe_get(fp_json, "morgan_r3_bits", None)
|
| 380 |
+
if isinstance(fp_json, dict)
|
| 381 |
+
else (fp_json if isinstance(fp_json, list) else None)
|
| 382 |
+
)
|
| 383 |
if bits is None:
|
| 384 |
fp_sample = [0] * FP_LENGTH
|
| 385 |
else:
|
|
|
|
| 392 |
normalized.append(1 if int(b) != 0 else 0)
|
| 393 |
else:
|
| 394 |
normalized.append(0)
|
| 395 |
+
|
| 396 |
if len(normalized) >= FP_LENGTH:
|
| 397 |
break
|
| 398 |
+
|
| 399 |
if len(normalized) < FP_LENGTH:
|
| 400 |
normalized.extend([0] * (FP_LENGTH - len(normalized)))
|
| 401 |
fp_sample = normalized[:FP_LENGTH]
|
| 402 |
|
| 403 |
+
# -------- PSMILES modality --------
|
| 404 |
if has_psmiles:
|
| 405 |
s = row.get("psmiles", "")
|
| 406 |
+
psmiles_raw = "" if s is None else str(s)
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
+
# Require at least 2 modalities to keep sample (same logic as your original)
|
| 409 |
+
modalities_present = sum(
|
| 410 |
+
[1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
|
| 411 |
+
)
|
| 412 |
if modalities_present >= 2:
|
| 413 |
sample = {
|
| 414 |
"gine": gine_sample,
|
| 415 |
"schnet": schnet_sample,
|
| 416 |
"fp": fp_sample,
|
| 417 |
+
"psmiles_raw": psmiles_raw,
|
| 418 |
}
|
| 419 |
+
|
| 420 |
+
sample_path = os.path.join(preproc_dir, f"sample_{sample_idx:08d}.pt")
|
| 421 |
try:
|
| 422 |
torch.save(sample, sample_path)
|
| 423 |
except Exception as save_e:
|
| 424 |
print("Warning: failed to torch.save sample:", save_e)
|
| 425 |
+
# fallback JSON for debugging (kept from your original)
|
| 426 |
try:
|
| 427 |
with open(sample_path + ".json", "w") as fjson:
|
| 428 |
json.dump(sample, fjson)
|
|
|
|
|
|
|
| 429 |
except Exception:
|
| 430 |
pass
|
| 431 |
|
| 432 |
sample_idx += 1
|
| 433 |
+
rows_written += 1
|
| 434 |
|
| 435 |
+
if rows_written >= target_rows:
|
|
|
|
| 436 |
break
|
| 437 |
|
| 438 |
+
print(f"Wrote {sample_idx} sample files to {preproc_dir}.")
|
| 439 |
+
return [str(p) for p in sorted(Path(preproc_dir).glob("sample_*.pt"))]
|
| 440 |
|
|
|
|
| 441 |
|
| 442 |
+
# =============================================================================
|
| 443 |
+
# Dataset + collate
|
| 444 |
+
# =============================================================================
|
|
|
|
|
|
|
| 445 |
|
|
|
|
|
|
|
|
|
|
| 446 |
class LazyMultimodalDataset(Dataset):
|
| 447 |
+
"""
|
| 448 |
+
Lazily loads per-sample files from disk and converts them into tensors.
|
| 449 |
+
|
| 450 |
+
Each sample file is expected to contain:
|
| 451 |
+
- gine: dict or None
|
| 452 |
+
- schnet: dict or None
|
| 453 |
+
- fp: list[int] or tensor
|
| 454 |
+
- psmiles_raw: str
|
| 455 |
+
"""
|
| 456 |
+
|
| 457 |
+
def __init__(self, sample_file_list: List[str], tokenizer, fp_length: int = FP_LENGTH, psmiles_max_len: int = PSMILES_MAX_LEN):
    """
    Store the per-sample file paths and tokenization settings.

    Args:
        sample_file_list: paths to serialized per-sample files (``sample_*.pt``).
        tokenizer: tokenizer used to encode the raw PSMILES string in __getitem__.
        fp_length: fixed fingerprint vector length.
        psmiles_max_len: max token length for PSMILES encoding.
    """
    self.files, self.tokenizer = sample_file_list, tokenizer
    self.fp_length, self.psmiles_max_len = fp_length, psmiles_max_len
|
| 462 |
|
| 463 |
+
def __len__(self) -> int:
    """Dataset size — one item per sample file."""
    return len(self.files)
|
| 465 |
|
| 466 |
+
def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 467 |
sample_path = self.files[idx]
|
| 468 |
+
|
| 469 |
+
# prefer torch.load if .pt, else try json (kept behavior)
|
| 470 |
if sample_path.endswith(".pt"):
|
| 471 |
sample = torch.load(sample_path, map_location="cpu")
|
| 472 |
else:
|
|
|
|
| 473 |
with open(sample_path, "r") as f:
|
| 474 |
sample = json.load(f)
|
| 475 |
|
| 476 |
+
# ---- GINE tensors ----
|
| 477 |
gine_raw = sample.get("gine", None)
|
|
|
|
| 478 |
if gine_raw:
|
| 479 |
node_atomic = torch.tensor(gine_raw.get("node_atomic", []), dtype=torch.long)
|
| 480 |
node_chirality = torch.tensor(gine_raw.get("node_chirality", []), dtype=torch.float)
|
| 481 |
node_charge = torch.tensor(gine_raw.get("node_charge", []), dtype=torch.float)
|
| 482 |
+
|
| 483 |
if gine_raw.get("edge_index", None) is not None:
|
| 484 |
+
edge_index = torch.tensor(gine_raw["edge_index"], dtype=torch.long)
|
|
|
|
| 485 |
else:
|
| 486 |
edge_index = torch.tensor([[], []], dtype=torch.long)
|
| 487 |
+
|
| 488 |
ea_raw = gine_raw.get("edge_attr", None)
|
| 489 |
if ea_raw:
|
| 490 |
edge_attr = torch.tensor(ea_raw, dtype=torch.float)
|
| 491 |
else:
|
| 492 |
edge_attr = torch.zeros((edge_index.size(1), 3), dtype=torch.float)
|
|
|
|
|
|
|
|
|
|
| 493 |
|
| 494 |
+
gine_item = {
|
| 495 |
+
"z": node_atomic,
|
| 496 |
+
"chirality": node_chirality,
|
| 497 |
+
"formal_charge": node_charge,
|
| 498 |
+
"edge_index": edge_index,
|
| 499 |
+
"edge_attr": edge_attr,
|
| 500 |
+
}
|
| 501 |
+
else:
|
| 502 |
+
gine_item = {
|
| 503 |
+
"z": torch.tensor([], dtype=torch.long),
|
| 504 |
+
"chirality": torch.tensor([], dtype=torch.float),
|
| 505 |
+
"formal_charge": torch.tensor([], dtype=torch.float),
|
| 506 |
+
"edge_index": torch.tensor([[], []], dtype=torch.long),
|
| 507 |
+
"edge_attr": torch.zeros((0, 3), dtype=torch.float),
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
# ---- SchNet tensors ----
|
| 511 |
schnet_raw = sample.get("schnet", None)
|
| 512 |
if schnet_raw:
|
| 513 |
s_z = torch.tensor(schnet_raw.get("atomic", []), dtype=torch.long)
|
|
|
|
| 516 |
else:
|
| 517 |
schnet_item = {"z": torch.tensor([], dtype=torch.long), "pos": torch.tensor([], dtype=torch.float)}
|
| 518 |
|
| 519 |
+
# ---- Fingerprint tensors ----
|
| 520 |
fp_raw = sample.get("fp", None)
|
| 521 |
if fp_raw is None:
|
| 522 |
fp_vec = torch.zeros((self.fp_length,), dtype=torch.long)
|
| 523 |
else:
|
|
|
|
| 524 |
if isinstance(fp_raw, (list, tuple)):
|
| 525 |
arr = list(fp_raw)[:self.fp_length]
|
| 526 |
if len(arr) < self.fp_length:
|
|
|
|
| 529 |
elif isinstance(fp_raw, torch.Tensor):
|
| 530 |
fp_vec = fp_raw.clone().to(torch.long)
|
| 531 |
else:
|
|
|
|
| 532 |
fp_vec = torch.zeros((self.fp_length,), dtype=torch.long)
|
| 533 |
|
| 534 |
+
# ---- PSMILES tensors ----
|
| 535 |
+
psm_raw = sample.get("psmiles_raw", "") or ""
|
|
|
|
|
|
|
| 536 |
enc = self.tokenizer(psm_raw, truncation=True, padding="max_length", max_length=self.psmiles_max_len)
|
| 537 |
p_input_ids = torch.tensor(enc["input_ids"], dtype=torch.long)
|
| 538 |
p_attn = torch.tensor(enc["attention_mask"], dtype=torch.bool)
|
| 539 |
|
| 540 |
return {
|
| 541 |
+
"gine": {
|
| 542 |
+
"z": gine_item["z"],
|
| 543 |
+
"chirality": gine_item["chirality"],
|
| 544 |
+
"formal_charge": gine_item["formal_charge"],
|
| 545 |
+
"edge_index": gine_item["edge_index"],
|
| 546 |
+
"edge_attr": gine_item["edge_attr"],
|
| 547 |
+
"num_nodes": int(gine_item["z"].size(0)) if gine_item["z"].numel() > 0 else 0,
|
| 548 |
+
},
|
| 549 |
"schnet": {"z": schnet_item["z"], "pos": schnet_item["pos"]},
|
| 550 |
"fp": {"input_ids": fp_vec},
|
| 551 |
+
"psmiles": {"input_ids": p_input_ids, "attention_mask": p_attn},
|
| 552 |
}
|
| 553 |
|
|
|
|
|
|
|
| 554 |
|
| 555 |
+
def multimodal_collate(batch_list: List[Dict[str, Dict[str, torch.Tensor]]]) -> Dict[str, Dict[str, torch.Tensor]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
"""
|
| 557 |
+
Collate a list of LazyMultimodalDataset samples into a single multimodal batch.
|
| 558 |
+
|
| 559 |
+
Output keys:
|
| 560 |
+
- gine: {z, chirality, formal_charge, edge_index, edge_attr, batch}
|
| 561 |
+
- schnet: {z, pos, batch}
|
| 562 |
+
- fp: {input_ids, attention_mask}
|
| 563 |
+
- psmiles: {input_ids, attention_mask}
|
| 564 |
"""
|
| 565 |
+
# ---- GINE batching ----
|
| 566 |
+
all_z, all_ch, all_fc = [], [], []
|
| 567 |
+
all_edge_index, all_edge_attr = [], []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
batch_mapping = []
|
| 569 |
node_offset = 0
|
| 570 |
+
|
| 571 |
for i, item in enumerate(batch_list):
|
| 572 |
g = item["gine"]
|
| 573 |
z = g["z"]
|
| 574 |
n = z.size(0)
|
| 575 |
+
|
| 576 |
all_z.append(z)
|
| 577 |
all_ch.append(g["chirality"])
|
| 578 |
all_fc.append(g["formal_charge"])
|
| 579 |
batch_mapping.append(torch.full((n,), i, dtype=torch.long))
|
| 580 |
+
|
| 581 |
if g["edge_index"] is not None and g["edge_index"].numel() > 0:
|
| 582 |
ei_offset = g["edge_index"] + node_offset
|
| 583 |
all_edge_index.append(ei_offset)
|
| 584 |
+
|
| 585 |
+
# REUSED helper from GINE.py
|
| 586 |
ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
|
| 587 |
all_edge_attr.append(ea)
|
| 588 |
+
|
| 589 |
node_offset += n
|
| 590 |
+
|
| 591 |
if len(all_z) == 0:
|
|
|
|
| 592 |
z_batch = torch.tensor([], dtype=torch.long)
|
| 593 |
ch_batch = torch.tensor([], dtype=torch.float)
|
| 594 |
fc_batch = torch.tensor([], dtype=torch.float)
|
| 595 |
batch_batch = torch.tensor([], dtype=torch.long)
|
| 596 |
+
edge_index_batched = torch.empty((2, 0), dtype=torch.long)
|
| 597 |
+
edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
|
| 598 |
else:
|
| 599 |
z_batch = torch.cat(all_z, dim=0)
|
| 600 |
ch_batch = torch.cat(all_ch, dim=0)
|
| 601 |
fc_batch = torch.cat(all_fc, dim=0)
|
| 602 |
batch_batch = torch.cat(batch_mapping, dim=0)
|
| 603 |
+
|
| 604 |
if len(all_edge_index) > 0:
|
| 605 |
edge_index_batched = torch.cat(all_edge_index, dim=1)
|
| 606 |
edge_attr_batched = torch.cat(all_edge_attr, dim=0)
|
| 607 |
else:
|
| 608 |
+
edge_index_batched = torch.empty((2, 0), dtype=torch.long)
|
| 609 |
+
edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
|
| 610 |
|
| 611 |
+
# ---- SchNet batching ----
|
| 612 |
+
all_sz, all_pos, schnet_batch = [], [], []
|
|
|
|
|
|
|
| 613 |
for i, item in enumerate(batch_list):
|
| 614 |
s = item["schnet"]
|
| 615 |
s_z = s["z"]
|
|
|
|
| 619 |
all_sz.append(s_z)
|
| 620 |
all_pos.append(s_pos)
|
| 621 |
schnet_batch.append(torch.full((s_z.size(0),), i, dtype=torch.long))
|
| 622 |
+
|
| 623 |
if len(all_sz) == 0:
|
| 624 |
s_z_batch = torch.tensor([], dtype=torch.long)
|
| 625 |
s_pos_batch = torch.tensor([], dtype=torch.float)
|
|
|
|
| 629 |
s_pos_batch = torch.cat(all_pos, dim=0)
|
| 630 |
s_batch_batch = torch.cat(schnet_batch, dim=0)
|
| 631 |
|
| 632 |
+
# ---- FP batching ----
|
| 633 |
+
fp_ids = torch.stack(
|
| 634 |
+
[
|
| 635 |
+
item["fp"]["input_ids"] if isinstance(item["fp"]["input_ids"], torch.Tensor)
|
| 636 |
+
else torch.tensor(item["fp"]["input_ids"], dtype=torch.long)
|
| 637 |
+
for item in batch_list
|
| 638 |
+
],
|
| 639 |
+
dim=0
|
| 640 |
+
)
|
| 641 |
fp_attn = torch.ones_like(fp_ids, dtype=torch.bool)
|
| 642 |
|
| 643 |
+
# ---- PSMILES batching ----
|
| 644 |
p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch_list], dim=0)
|
| 645 |
p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch_list], dim=0)
|
| 646 |
|
| 647 |
return {
|
| 648 |
+
"gine": {
|
| 649 |
+
"z": z_batch,
|
| 650 |
+
"chirality": ch_batch,
|
| 651 |
+
"formal_charge": fc_batch,
|
| 652 |
+
"edge_index": edge_index_batched,
|
| 653 |
+
"edge_attr": edge_attr_batched,
|
| 654 |
+
"batch": batch_batch,
|
| 655 |
+
},
|
| 656 |
"schnet": {"z": s_z_batch, "pos": s_pos_batch, "batch": s_batch_batch},
|
| 657 |
"fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
|
| 658 |
+
"psmiles": {"input_ids": p_ids, "attention_mask": p_attn},
|
| 659 |
}
|
| 660 |
|
|
|
|
|
|
|
| 661 |
|
| 662 |
+
def build_dataloaders(
    sample_files: List[str],
    tokenizer,
    train_batch_size: int,
    eval_batch_size: int,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader, torch.utils.data.Subset, torch.utils.data.Subset]:
    """
    Split the sample files into train/val subsets (80/20) and wrap each in a
    DataLoader using the multimodal collate function.

    Args:
        sample_files: paths to serialized per-sample files.
        tokenizer: tokenizer forwarded to LazyMultimodalDataset.
        train_batch_size: batch size of the (shuffled) training loader.
        eval_batch_size: batch size of the (sequential) validation loader.
        seed: random_state for the sklearn split, for reproducibility.

    Returns:
        (train_loader, val_loader, train_subset, val_subset)
    """
    full_ds = LazyMultimodalDataset(
        sample_files, tokenizer, fp_length=FP_LENGTH, psmiles_max_len=PSMILES_MAX_LEN
    )

    idx_train, idx_val = train_test_split(
        list(range(len(full_ds))), test_size=0.2, random_state=seed
    )
    subset_train = torch.utils.data.Subset(full_ds, idx_train)
    subset_val = torch.utils.data.Subset(full_ds, idx_val)

    # num_workers=0 keeps loading in-process (samples are already on disk).
    shared_kwargs = dict(collate_fn=multimodal_collate, num_workers=0, drop_last=False)
    loader_train = DataLoader(
        subset_train, batch_size=train_batch_size, shuffle=True, **shared_kwargs
    )
    loader_val = DataLoader(
        subset_val, batch_size=eval_batch_size, shuffle=False, **shared_kwargs
    )
    return loader_train, loader_val, subset_train, subset_val
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
# =============================================================================
|
| 698 |
+
# Multimodal contrastive model
|
| 699 |
+
# =============================================================================
|
| 700 |
+
|
| 701 |
class MultimodalContrastiveModel(nn.Module):
    """
    Contrastive wrapper around the four unimodal encoders (GINE graph, SchNet
    geometry, fingerprint transformer, PSMILES DeBERTa).

    total_loss = InfoNCE + REC_LOSS_WEIGHT * mean(reconstruction losses),
    where InfoNCE contrasts the `mask_target` modality embedding against the
    mean embedding of the remaining modalities, and reconstruction terms are
    included only for modalities whose batch carries a `labels` entry.
    """

    def __init__(
        self,
        gine_encoder: Optional[GineEncoder],
        schnet_encoder: Optional[NodeSchNetWrapper],
        fp_encoder: Optional[FingerprintEncoder],
        psmiles_encoder: Optional[PSMILESDebertaEncoder],
        emb_dim: int = 600,
    ):
        super().__init__()
        self.gine = gine_encoder
        self.schnet = schnet_encoder
        self.fp = fp_encoder
        self.psmiles = psmiles_encoder

        def _projection(encoder):
            # Each encoder exposes `pool_proj`; map its pooled output width
            # into the shared contrastive space (None when encoder is absent).
            if encoder is None:
                return None
            return nn.Linear(encoder.pool_proj.out_features, emb_dim)

        self.proj_gine = _projection(self.gine)
        self.proj_schnet = _projection(self.schnet)
        self.proj_fp = _projection(self.fp)
        self.proj_psmiles = _projection(self.psmiles)

        self.temperature = TEMPERATURE
        # ignore_index=-100 matches the "unmasked" label convention used by
        # mask_batch_for_modality.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")

    def encode(self, batch_mods: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return L2-normalized, projected embeddings for each available modality."""
        projected: Dict[str, torch.Tensor] = {}

        if "gine" in batch_mods and self.gine is not None:
            g = batch_mods["gine"]
            pooled = self.gine(g["z"], g["chirality"], g["formal_charge"], g["edge_index"], g["edge_attr"], g.get("batch", None))
            projected["gine"] = F.normalize(self.proj_gine(pooled), dim=-1)

        if "schnet" in batch_mods and self.schnet is not None:
            s = batch_mods["schnet"]
            pooled = self.schnet(s["z"], s["pos"], s.get("batch", None))
            projected["schnet"] = F.normalize(self.proj_schnet(pooled), dim=-1)

        if "fp" in batch_mods and self.fp is not None:
            f = batch_mods["fp"]
            pooled = self.fp(f["input_ids"], f.get("attention_mask", None))
            projected["fp"] = F.normalize(self.proj_fp(pooled), dim=-1)

        if "psmiles" in batch_mods and self.psmiles is not None:
            p = batch_mods["psmiles"]
            pooled = self.psmiles(p["input_ids"], p.get("attention_mask", None))
            projected["psmiles"] = F.normalize(self.proj_psmiles(pooled), dim=-1)

        return projected

    def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str):
        """
        Compute the combined loss for one batch.

        Args:
            batch_mods: per-modality input dicts (see mm_batch_to_model_input).
            mask_target: name of the modality acting as the InfoNCE target.

        Returns:
            (total_loss, info) where info carries scalar diagnostics
            ("info_nce_loss", "reconstruction_loss", "total_loss", plus
            per-modality reconstruction terms when computed), or a
            {"batch_size": ...} dict with zero loss on degenerate batches.
        """
        device = next(self.parameters()).device
        modal_embs = self.encode(batch_mods)
        metrics = {}

        # Degenerate cases: missing target modality, or nothing to anchor on.
        if mask_target not in modal_embs:
            return torch.tensor(0.0, device=device), {"batch_size": 0}

        target = modal_embs[mask_target]
        anchor_keys = [k for k in modal_embs.keys() if k != mask_target]
        if len(anchor_keys) == 0:
            return torch.tensor(0.0, device=device), {"batch_size": target.size(0)}

        # InfoNCE: the anchor is the mean embedding of the non-target
        # modalities; positives sit on the diagonal of the similarity matrix.
        anchor = torch.stack([modal_embs[k] for k in anchor_keys], dim=0).mean(dim=0)
        sim = torch.matmul(anchor, target.T) / self.temperature
        diag = torch.arange(sim.size(0), device=sim.device)
        info_nce_loss = F.cross_entropy(sim, diag)
        metrics["info_nce_loss"] = float(info_nce_loss.detach().cpu().item())

        # Optional reconstruction terms, collected best-effort per modality.
        recon_terms = []
        recon_metrics = {}

        # GINE: predict original atomic ids at masked nodes.
        try:
            if "gine" in batch_mods and self.gine is not None:
                gm = batch_mods["gine"]
                node_labels = gm.get("labels", None)
                if node_labels is not None:
                    logits_nodes = self.gine.node_logits(gm["z"], gm["chirality"], gm["formal_charge"], gm["edge_index"], gm["edge_attr"])
                    if node_labels.dim() == 1 and logits_nodes.size(0) == node_labels.size(0):
                        gine_term = self.ce_loss(logits_nodes, node_labels.to(logits_nodes.device))
                        recon_terms.append(gine_term)
                        recon_metrics["gine_rec_loss"] = float(gine_term.detach().cpu().item())
        except Exception as e:
            print("Warning: GINE reconstruction loss computation failed:", e)

        # SchNet: predict original atomic ids at perturbed nodes.
        try:
            if "schnet" in batch_mods and self.schnet is not None:
                sm = batch_mods["schnet"]
                node_labels = sm.get("labels", None)
                if node_labels is not None:
                    logits_nodes = self.schnet.node_logits(sm["z"], sm["pos"], sm.get("batch", None))
                    if node_labels.dim() == 1 and logits_nodes.size(0) == node_labels.size(0):
                        schnet_term = self.ce_loss(logits_nodes, node_labels.to(logits_nodes.device))
                        recon_terms.append(schnet_term)
                        recon_metrics["schnet_rec_loss"] = float(schnet_term.detach().cpu().item())
        except Exception as e:
            print("Warning: SchNet reconstruction loss computation failed:", e)

        # Fingerprint: token-level CE over the flattened (B*L, V) logits.
        try:
            if "fp" in batch_mods and self.fp is not None:
                fm = batch_mods["fp"]
                fp_labels = fm.get("labels", None)
                if fp_labels is not None:
                    tok_logits = self.fp.token_logits(fm["input_ids"], fm.get("attention_mask", None))
                    vocab = tok_logits.shape[-1]
                    flat_logits = tok_logits.view(-1, vocab)
                    flat_labels = fp_labels.view(-1).to(flat_logits.device)
                    fp_term = self.ce_loss(flat_logits, flat_labels)
                    recon_terms.append(fp_term)
                    recon_metrics["fp_rec_loss"] = float(fp_term.detach().cpu().item())
        except Exception as e:
            print("Warning: FP reconstruction loss computation failed:", e)

        # PSMILES: the DeBERTa wrapper returns the MLM loss directly when
        # labels are supplied.
        try:
            if "psmiles" in batch_mods and self.psmiles is not None:
                pm = batch_mods["psmiles"]
                ps_labels = pm.get("labels", None)
                if ps_labels is not None:
                    ps_term = self.psmiles.token_logits(pm["input_ids"], pm.get("attention_mask", None), labels=ps_labels)
                    if isinstance(ps_term, torch.Tensor):
                        recon_terms.append(ps_term)
                        recon_metrics["psmiles_mlm_loss"] = float(ps_term.detach().cpu().item())
        except Exception as e:
            print("Warning: PSMILES MLM loss computation failed:", e)

        if len(recon_terms) > 0:
            recon_mean = sum(recon_terms) / len(recon_terms)
            metrics["reconstruction_loss"] = float(recon_mean.detach().cpu().item())
            total_loss = info_nce_loss + REC_LOSS_WEIGHT * recon_mean
            metrics["total_loss"] = float(total_loss.detach().cpu().item())
            metrics.update(recon_metrics)
        else:
            total_loss = info_nce_loss
            metrics["reconstruction_loss"] = 0.0
            metrics["total_loss"] = float(total_loss.detach().cpu().item())

        return total_loss, metrics
|
| 852 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
|
| 854 |
+
# =============================================================================
|
| 855 |
+
# Masking utilities
|
| 856 |
+
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 857 |
|
| 858 |
+
def mask_batch_for_modality(batch: dict, modality: str, tokenizer, p_mask: float = P_MASK) -> dict:
    """
    Apply BERT-style masking (80% mask / 10% random / 10% keep, per selection)
    to EVERY modality present in `batch`, attaching `labels` to each
    (-100 marks unmasked positions, matching CrossEntropyLoss ignore_index).

    NOTE(review): despite the name, the `modality` argument is currently
    unused — all modalities are masked unconditionally; the contrastive target
    is chosen separately by the caller. The parameter is kept for
    backward compatibility.

    Args:
        batch: collated multimodal batch (see `multimodal_collate`).
        modality: unused (see note above).
        tokenizer: PSMILES tokenizer providing the mask token id. If None,
            PSMILES passes through with all labels = -100 (no MLM loss).
        p_mask: per-position masking probability.

    Returns:
        New dict with the same modality keys; inputs cloned where mutated.
    """
    b = {}

    # ---------------- GINE ----------------
    if "gine" in batch:
        z = batch["gine"]["z"].clone()
        chir = batch["gine"]["chirality"].clone()
        fc = batch["gine"]["formal_charge"].clone()
        edge_index = batch["gine"]["edge_index"]
        edge_attr = batch["gine"]["edge_attr"]
        batch_map = batch["gine"].get("batch", None)

        n_nodes = z.size(0)
        dev = z.device
        is_selected = torch.rand(n_nodes, device=dev) < p_mask
        # Never mask every node: keep at least one visible for context.
        if is_selected.numel() > 0 and is_selected.all():
            is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False

        labels_z = torch.full_like(z, fill_value=-100)
        if is_selected.any():
            sel_idx = torch.nonzero(is_selected).squeeze(-1)
            if sel_idx.dim() == 0:
                sel_idx = sel_idx.unsqueeze(0)

            labels_z[is_selected] = z[is_selected]
            rand_atomic = torch.randint(1, MAX_ATOMIC_Z + 1, (sel_idx.size(0),), dtype=torch.long, device=dev)

            # 80% -> MASK_ATOM_ID, 10% -> random atom, 10% -> keep original.
            probs = torch.rand(sel_idx.size(0), device=dev)
            mask_choice = probs < 0.8
            rand_choice = (probs >= 0.8) & (probs < 0.9)

            if mask_choice.any():
                z[sel_idx[mask_choice]] = MASK_ATOM_ID
            if rand_choice.any():
                z[sel_idx[rand_choice]] = rand_atomic[rand_choice]

        b["gine"] = {
            "z": z,
            "chirality": chir,
            "formal_charge": fc,
            "edge_index": edge_index,
            "edge_attr": edge_attr,
            "batch": batch_map,
            "labels": labels_z,
        }

    # ---------------- SchNet ----------------
    if "schnet" in batch:
        z = batch["schnet"]["z"].clone()
        pos = batch["schnet"]["pos"].clone()
        batch_map = batch["schnet"].get("batch", None)

        n_nodes = z.size(0)
        dev = z.device
        is_selected = torch.rand(n_nodes, device=dev) < p_mask
        if is_selected.numel() > 0 and is_selected.all():
            is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False

        labels_z = torch.full((n_nodes,), -100, dtype=torch.long, device=dev)
        if is_selected.any():
            sel_idx = torch.nonzero(is_selected).squeeze(-1)
            if sel_idx.dim() == 0:
                sel_idx = sel_idx.unsqueeze(0)

            labels_z[is_selected] = z[is_selected]
            probs_c = torch.rand(sel_idx.size(0), device=dev)
            noisy_choice = probs_c < 0.8
            randpos_choice = (probs_c >= 0.8) & (probs_c < 0.9)

            # 80%: jitter coordinates with Gaussian noise (sigma = 0.5).
            if noisy_choice.any():
                idx = sel_idx[noisy_choice]
                noise = torch.randn((idx.size(0), 3), device=pos.device) * 0.5
                pos[idx] = pos[idx] + noise

            # 10%: move atom to a uniform random point in the bounding box.
            if randpos_choice.any():
                idx = sel_idx[randpos_choice]
                mins = pos.min(dim=0).values
                maxs = pos.max(dim=0).values
                randpos = (torch.rand((idx.size(0), 3), device=pos.device) * (maxs - mins)) + mins
                pos[idx] = randpos

        b["schnet"] = {"z": z, "pos": pos, "batch": batch_map, "labels": labels_z}

    # ---------------- FP ----------------
    if "fp" in batch:
        input_ids = batch["fp"]["input_ids"].clone()
        attn = batch["fp"].get("attention_mask", torch.ones_like(input_ids, dtype=torch.bool))

        B, L = input_ids.shape
        dev = input_ids.device
        labels_z = torch.full_like(input_ids, -100)

        for i in range(B):
            sel = torch.rand(L, device=dev) < p_mask
            if sel.numel() > 0 and sel.all():
                sel[torch.randint(0, L, (1,), device=dev)] = False

            sel_idx = torch.nonzero(sel).squeeze(-1)
            if sel_idx.numel() > 0:
                if sel_idx.dim() == 0:
                    sel_idx = sel_idx.unsqueeze(0)

                labels_z[i, sel_idx] = input_ids[i, sel_idx]

                probs = torch.rand(sel_idx.size(0), device=dev)
                mask_choice = probs < 0.8
                rand_choice = (probs >= 0.8) & (probs < 0.9)

                if mask_choice.any():
                    input_ids[i, sel_idx[mask_choice]] = MASK_TOKEN_ID_FP
                if rand_choice.any():
                    # Random replacement draws from the binary bit vocabulary.
                    rand_bits = torch.randint(0, 2, (rand_choice.sum().item(),), dtype=torch.long, device=dev)
                    input_ids[i, sel_idx[rand_choice]] = rand_bits

        b["fp"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}

    # ---------------- PSMILES ----------------
    if "psmiles" in batch:
        input_ids = batch["psmiles"]["input_ids"].clone()
        attn = batch["psmiles"]["attention_mask"].clone()

        B, L = input_ids.shape
        dev = input_ids.device
        labels_z = torch.full_like(input_ids, -100)

        # If tokenizer is unavailable, keep labels=-100 (no MLM loss).
        if tokenizer is None:
            b["psmiles"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}
        else:
            mask_token_id = tokenizer.mask_token_id if getattr(tokenizer, "mask_token_id", None) is not None else getattr(tokenizer, "vocab", {}).get("<mask>", 1)

            for i in range(B):
                # FIX: restrict candidates to real tokens — previously padding
                # positions (attention_mask == 0) could be masked and would
                # contribute spurious targets to the MLM loss.
                sel = (torch.rand(L, device=dev) < p_mask) & attn[i].bool()
                if sel.numel() > 0 and sel.all():
                    sel[torch.randint(0, L, (1,), device=dev)] = False

                sel_idx = torch.nonzero(sel).squeeze(-1)
                if sel_idx.numel() > 0:
                    if sel_idx.dim() == 0:
                        sel_idx = sel_idx.unsqueeze(0)

                    labels_z[i, sel_idx] = input_ids[i, sel_idx]

                    probs = torch.rand(sel_idx.size(0), device=dev)
                    mask_choice = probs < 0.8
                    rand_choice = (probs >= 0.8) & (probs < 0.9)

                    if mask_choice.any():
                        input_ids[i, sel_idx[mask_choice]] = mask_token_id
                    if rand_choice.any():
                        rand_ids = torch.randint(0, getattr(tokenizer, "vocab_size", 300), (rand_choice.sum().item(),), dtype=torch.long, device=dev)
                        input_ids[i, sel_idx[rand_choice]] = rand_ids

            b["psmiles"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}

    return b
|
| 1018 |
|
| 1019 |
+
|
| 1020 |
+
def mm_batch_to_model_input(masked_batch: dict) -> dict:
    """
    Normalize the masked batch dict into the exact structure expected by
    MultimodalContrastiveModel.

    Required keys are looked up directly (raising KeyError if absent, as before);
    optional keys default to None. Semantics are identical to the original
    per-modality if-blocks, just table-driven.
    """
    # Keys that must be present per modality (direct indexing, may raise KeyError).
    required = {
        "gine": ("z", "chirality", "formal_charge", "edge_index", "edge_attr"),
        "schnet": ("z", "pos"),
        "fp": ("input_ids",),
        "psmiles": ("input_ids",),
    }
    # Keys that default to None when missing.
    optional = {
        "gine": ("batch", "labels"),
        "schnet": ("batch", "labels"),
        "fp": ("attention_mask", "labels"),
        "psmiles": ("attention_mask", "labels"),
    }

    mm = {}
    for modality, must_have in required.items():
        if modality not in masked_batch:
            continue
        src = masked_batch[modality]
        entry = {key: src[key] for key in must_have}
        entry.update({key: src.get(key, None) for key in optional[modality]})
        mm[modality] = entry
    return mm
|
| 1047 |
|
| 1048 |
+
|
| 1049 |
+
# =============================================================================
|
| 1050 |
+
# Evaluation
|
| 1051 |
+
# =============================================================================
|
| 1052 |
+
|
| 1053 |
+
def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader: DataLoader, device: torch.device, tokenizer, mask_target: str = "fp") -> Dict[str, float]:
    """
    Contrastive-only validation pass.

    - masks the `mask_target` modality
    - computes InfoNCE logits = anchor·target / T, where the anchor is the mean
      of all other modality embeddings
    - reports example-weighted eval_loss, top-1 retrieval accuracy and weighted F1
    """
    model.eval()
    running_loss = 0.0
    running_acc = 0.0
    running_f1 = 0.0
    seen = 0

    with torch.no_grad():
        for raw_batch in val_loader:
            masked = mask_batch_for_modality(raw_batch, mask_target, tokenizer=tokenizer, p_mask=P_MASK)

            # Move every tensor in the nested batch onto the eval device.
            for mod_name, sub in masked.items():
                masked[mod_name] = {
                    key: (val.to(device) if isinstance(val, torch.Tensor) else val)
                    for key, val in sub.items()
                }

            embs = model.encode(mm_batch_to_model_input(masked))
            if mask_target not in embs:
                continue

            target = embs[mask_target]
            anchors = [embs[name] for name in embs if name != mask_target]
            if not anchors:
                continue

            # Mean of the unmasked modality embeddings acts as the anchor.
            anchor = torch.stack(anchors, dim=0).mean(dim=0)
            logits = torch.matmul(anchor, target.T) / model.temperature

            n = logits.size(0)
            labels = torch.arange(n, device=logits.device)

            running_loss += F.cross_entropy(logits, labels).item() * n
            seen += n

            preds = logits.argmax(dim=1)  # top-1 retrieval prediction per anchor row
            acc = (preds == labels).float().mean().item()
            running_acc += acc * n

            # Weighted F1 over instance IDs; degenerate batches fall back to accuracy.
            try:
                y_true = labels.cpu().numpy()
                y_pred = preds.cpu().numpy()
                if len(np.unique(y_true)) < 2:
                    f1_val = float(acc)
                else:
                    f1_val = f1_score(y_true, y_pred, average="weighted")
            except Exception:
                f1_val = float(acc)
            running_f1 += f1_val * n

    if seen == 0:
        return {"eval_loss": float("nan"), "eval_accuracy": 0.0, "eval_f1_weighted": 0.0}

    return {
        "eval_loss": running_loss / seen,
        "eval_accuracy": running_acc / seen,
        "eval_f1_weighted": running_f1 / seen,
    }
|
| 1121 |
+
|
| 1122 |
|
| 1123 |
+
# =============================================================================
|
| 1124 |
+
# HF wrapper + collator + trainer
|
| 1125 |
+
# =============================================================================
|
| 1126 |
|
|
|
|
|
|
|
| 1127 |
class HFMultimodalModule(nn.Module):
|
| 1128 |
+
"""
|
| 1129 |
+
HuggingFace Trainer-facing wrapper:
|
| 1130 |
+
- Receives a full multimodal batch
|
| 1131 |
+
- Randomly masks one modality (provided by collator) inside forward
|
| 1132 |
+
- Returns a dict compatible with Trainer (loss, logits, labels)
|
| 1133 |
+
"""
|
| 1134 |
+
|
| 1135 |
+
    def __init__(self, mm_model: MultimodalContrastiveModel, tokenizer):
        """Wrap the contrastive model for use with the HF Trainer.

        mm_model: the MultimodalContrastiveModel whose forward produces the loss.
        tokenizer: PSMILES tokenizer handed to mask_batch_for_modality in forward
                   (may be None, in which case psmiles MLM labels stay -100).
        """
        super().__init__()
        self.mm = mm_model  # underlying multimodal contrastive model
        self._tokenizer = tokenizer  # used for psmiles masking inside forward
|
| 1139 |
|
| 1140 |
def forward(self, **kwargs):
|
| 1141 |
if "batch" in kwargs:
|
|
|
|
| 1148 |
batch = {k: found[k] for k in found}
|
| 1149 |
mask_target = kwargs.get("mask_target", "fp")
|
| 1150 |
else:
|
| 1151 |
+
raise ValueError(
|
| 1152 |
+
"HFMultimodalModule.forward could not find 'batch' nor modality keys in inputs. "
|
| 1153 |
+
f"Inputs keys: {list(kwargs.keys())}"
|
| 1154 |
+
)
|
| 1155 |
+
|
| 1156 |
+
masked_batch = mask_batch_for_modality(batch, mask_target, tokenizer=self._tokenizer, p_mask=P_MASK)
|
| 1157 |
+
|
| 1158 |
device = next(self.parameters()).device
|
| 1159 |
for k in masked_batch:
|
| 1160 |
for subk in list(masked_batch[k].keys()):
|
| 1161 |
val = masked_batch[k][subk]
|
| 1162 |
if isinstance(val, torch.Tensor):
|
| 1163 |
masked_batch[k][subk] = val.to(device)
|
| 1164 |
+
|
| 1165 |
mm_in = mm_batch_to_model_input(masked_batch)
|
| 1166 |
loss, info = self.mm(mm_in, mask_target)
|
| 1167 |
+
|
| 1168 |
logits = None
|
| 1169 |
labels = None
|
| 1170 |
try:
|
|
|
|
| 1182 |
print("Warning: failed to compute logits/labels inside HFMultimodalModule.forward:", e)
|
| 1183 |
logits = None
|
| 1184 |
labels = None
|
| 1185 |
+
|
| 1186 |
eval_loss = loss.detach() if isinstance(loss, torch.Tensor) else torch.tensor(float(loss), device=device)
|
| 1187 |
out = {"loss": loss, "eval_loss": eval_loss}
|
| 1188 |
if logits is not None:
|
|
|
|
| 1192 |
out["mm_info"] = info
|
| 1193 |
return out
|
| 1194 |
|
|
|
|
|
|
|
| 1195 |
|
| 1196 |
class ContrastiveDataCollator:
|
| 1197 |
+
"""
|
| 1198 |
+
Collator used by Trainer:
|
| 1199 |
+
- If given raw samples (list of dicts), it calls multimodal_collate
|
| 1200 |
+
- Then selects a random modality to mask (mask_target)
|
| 1201 |
+
"""
|
| 1202 |
+
|
| 1203 |
+
    def __init__(self, mask_prob: float = P_MASK, modalities: Optional[List[str]] = None):
        """Store the masking probability and the candidate modalities to mask.

        modalities defaults to all four encoders: gine, schnet, fp, psmiles.
        """
        self.mask_prob = mask_prob
        # None default avoids a shared mutable default list across instances.
        self.modalities = modalities if modalities is not None else ["gine", "schnet", "fp", "psmiles"]
|
| 1206 |
|
|
|
|
| 1209 |
collated = features
|
| 1210 |
mask_target = random.choice([m for m in self.modalities if m in collated])
|
| 1211 |
return {"batch": collated, "mask_target": mask_target}
|
| 1212 |
+
|
| 1213 |
if isinstance(features, (list, tuple)) and len(features) > 0:
|
| 1214 |
first = features[0]
|
| 1215 |
+
if isinstance(first, dict) and "gine" in first:
|
| 1216 |
collated = multimodal_collate(list(features))
|
| 1217 |
mask_target = random.choice([m for m in self.modalities if m in collated])
|
| 1218 |
return {"batch": collated, "mask_target": mask_target}
|
| 1219 |
+
|
| 1220 |
+
if isinstance(first, dict) and "batch" in first:
|
| 1221 |
+
collated = first["batch"]
|
| 1222 |
mask_target = first.get("mask_target", random.choice([m for m in self.modalities if m in collated]))
|
| 1223 |
return {"batch": collated, "mask_target": mask_target}
|
| 1224 |
+
|
| 1225 |
print("ContrastiveDataCollator received unexpected 'features' shape/type.")
|
| 1226 |
raise ValueError("ContrastiveDataCollator could not collate input. Expected list[dict] with 'gine' key or already-collated dict.")
|
| 1227 |
|
|
|
|
| 1228 |
|
| 1229 |
class VerboseTrainingCallback(TrainerCallback):
|
| 1230 |
+
"""
|
| 1231 |
+
Console-first training callback with early stopping on eval_loss.
|
| 1232 |
+
|
| 1233 |
+
Behavior is kept consistent with your original callback; changes are comment/structure only.
|
| 1234 |
+
"""
|
| 1235 |
+
|
| 1236 |
def __init__(self, patience: int = 10):
|
| 1237 |
self.start_time = time.time()
|
| 1238 |
self.epoch_start_time = time.time()
|
|
|
|
| 1260 |
|
| 1261 |
    def on_train_begin(self, args, state, control, **kwargs):
        """Print a start-of-training banner plus model parameter counts."""
        self.start_time = time.time()  # wall-clock reference for total training time
        print("=" * 80)
        print(" STARTING MULTIMODAL CONTRASTIVE LEARNING TRAINING")
        print("=" * 80)

        # Trainer passes the wrapped model via kwargs; counts cover all submodules.
        model = kwargs.get("model")
        if model is not None:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            non_trainable_params = total_params - trainable_params
            print(" MODEL PARAMETERS:")
            print(f" Total Parameters: {total_params:,}")
            print(f" Trainable Parameters: {trainable_params:,}")
            print(f" Non-trainable Parameters: {non_trainable_params:,}")
            print(f" Training Progress: 0/{args.num_train_epochs} epochs")

        print("=" * 80)
|
| 1279 |
|
| 1280 |
    def on_epoch_begin(self, args, state, control, **kwargs):
        """Reset the per-epoch timer and announce the epoch about to run."""
        self.epoch_start_time = time.time()
        # state.epoch is 0-based (possibly fractional); displayed as 1-based below.
        current_epoch = state.epoch if state is not None else 0.0
        print(f" Epoch {current_epoch + 1:.1f}/{args.num_train_epochs} Starting...")
|
| 1284 |
|
| 1285 |
    def on_epoch_end(self, args, state, control, **kwargs):
        """Cache the most recent non-zero training loss for later reporting."""
        train_loss = None
        # Walk the log history backwards to find the latest real "loss" entry;
        # zero-valued entries are treated as placeholders and skipped.
        for log in reversed(state.log_history):
            if isinstance(log, dict) and "loss" in log and float(log.get("loss", 0)) != 0.0:
                train_loss = log["loss"]
                break
        self._last_train_loss = train_loss
|
| 1292 |
|
| 1293 |
    def on_log(self, args, state, control, logs=None, **kwargs):
        """Echo the training loss roughly ten times per epoch."""
        if logs is not None and "loss" in logs:
            current_step = state.global_step
            current_epoch = state.epoch
            # NOTE(review): relies on the module-level `train_loader` assigned in
            # main() via `global`; falls back to 1 step/epoch if it is unavailable.
            try:
                steps_per_epoch = max(1, len(train_loader) // args.gradient_accumulation_steps)
            except Exception:
                steps_per_epoch = 1

            # Print about every 10% of an epoch.
            if current_step % max(1, steps_per_epoch // 10) == 0:
                progress = current_epoch + (current_step % steps_per_epoch) / steps_per_epoch
                print(f" Step {current_step:4d} | Epoch {progress:.1f} | Train Loss: {logs['loss']:.6f}")
|
|
|
|
| 1306 |
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 1307 |
current_epoch = state.epoch if state is not None else 0.0
|
| 1308 |
epoch_time = time.time() - self.epoch_start_time
|
| 1309 |
+
|
| 1310 |
+
hf_metrics = metrics if metrics is not None else kwargs.get("metrics", None)
|
| 1311 |
hf_eval_loss = None
|
| 1312 |
hf_train_loss = self._last_train_loss
|
| 1313 |
+
|
| 1314 |
if hf_metrics is not None:
|
| 1315 |
+
hf_eval_loss = hf_metrics.get("eval_loss", hf_metrics.get("loss", None))
|
| 1316 |
if hf_train_loss is None:
|
| 1317 |
+
hf_train_loss = hf_metrics.get("train_loss", hf_train_loss)
|
| 1318 |
+
|
| 1319 |
cl_metrics = {}
|
| 1320 |
try:
|
| 1321 |
+
model = kwargs.get("model", None)
|
| 1322 |
if model is not None:
|
| 1323 |
cl_model = model.mm if hasattr(model, "mm") else model
|
| 1324 |
+
cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
|
| 1325 |
else:
|
| 1326 |
+
cl_metrics = evaluate_multimodal(multimodal_model, val_loader, device, tokenizer, mask_target="fp")
|
| 1327 |
except Exception as e:
|
| 1328 |
print("Warning: evaluate_multimodal inside callback failed:", e)
|
| 1329 |
+
|
| 1330 |
if hf_eval_loss is None:
|
| 1331 |
+
hf_eval_loss = cl_metrics.get("eval_loss", None)
|
| 1332 |
+
|
| 1333 |
+
val_acc = cl_metrics.get("eval_accuracy", "N/A")
|
| 1334 |
+
val_f1 = cl_metrics.get("eval_f1_weighted", "N/A")
|
| 1335 |
+
|
| 1336 |
+
print(f" EPOCH {current_epoch + 1:.1f} RESULTS:")
|
| 1337 |
if hf_train_loss is not None:
|
| 1338 |
try:
|
| 1339 |
print(f" Train Loss (HF reported): {hf_train_loss:.6f}")
|
| 1340 |
except Exception:
|
| 1341 |
print(f" Train Loss (HF reported): {hf_train_loss}")
|
| 1342 |
else:
|
| 1343 |
+
print(" Train Loss (HF reported): N/A")
|
| 1344 |
+
|
| 1345 |
if hf_eval_loss is not None:
|
| 1346 |
try:
|
| 1347 |
print(f" Eval Loss (HF reported): {hf_eval_loss:.6f}")
|
| 1348 |
except Exception:
|
| 1349 |
print(f" Eval Loss (HF reported): {hf_eval_loss}")
|
| 1350 |
else:
|
| 1351 |
+
print(" Eval Loss (HF reported): N/A")
|
| 1352 |
+
|
| 1353 |
if isinstance(val_acc, float):
|
| 1354 |
print(f" Eval Acc (CL evaluator): {val_acc:.6f}")
|
| 1355 |
else:
|
| 1356 |
print(f" Eval Acc (CL evaluator): {val_acc}")
|
| 1357 |
+
|
| 1358 |
if isinstance(val_f1, float):
|
| 1359 |
print(f" Eval F1 Weighted (CL evaluator): {val_f1:.6f}")
|
| 1360 |
else:
|
| 1361 |
print(f" Eval F1 Weighted (CL evaluator): {val_f1}")
|
| 1362 |
+
|
| 1363 |
+
current_val = hf_eval_loss if hf_eval_loss is not None else float("inf")
|
| 1364 |
+
|
| 1365 |
if current_val < self.best_val_loss - 1e-6:
|
|
|
|
| 1366 |
self.best_val_loss = current_val
|
| 1367 |
self.best_epoch = current_epoch
|
| 1368 |
self.epochs_no_improve = 0
|
|
|
|
| 1372 |
print("Warning: saving best model failed:", e)
|
| 1373 |
else:
|
| 1374 |
self.epochs_no_improve += 1
|
| 1375 |
+
|
| 1376 |
if self.epochs_no_improve >= self.patience:
|
| 1377 |
print(f"Early stopping: no improvement in val_loss for {self.patience} epochs.")
|
| 1378 |
control.should_training_stop = True
|
| 1379 |
+
|
| 1380 |
print(f" Epoch Training Time: {epoch_time:.2f}s")
|
| 1381 |
print(f" Best Val Loss so far: {self.best_val_loss}")
|
| 1382 |
print(f" Epochs since improvement: {self.epochs_no_improve}/{self.patience}")
|
|
|
|
| 1384 |
|
| 1385 |
    def on_train_end(self, args, state, control, **kwargs):
        """Print the closing banner with total wall-clock time and epochs completed."""
        total_time = time.time() - self.start_time
        print("=" * 80)
        print(" TRAINING COMPLETED")
        print("=" * 80)
        print(f" Total Training Time: {total_time:.2f}s")
        if state is not None:
            # state.epoch can be None/odd types on aborted runs; report best-effort.
            try:
                print(f" Total Epochs Completed: {state.epoch + 1:.1f}")
            except Exception:
                pass
        print("=" * 80)
|
| 1397 |
+
|
| 1398 |
|
| 1399 |
+
class CLTrainer(Trainer):
|
| 1400 |
+
"""
|
| 1401 |
+
Custom Trainer:
|
| 1402 |
+
- evaluate(): merges HF eval with contrastive evaluator (same behavior)
|
| 1403 |
+
- _save(): saves a state_dict under pytorch_model.bin
|
| 1404 |
+
- _load_best_model(): loads best pytorch_model.bin
|
| 1405 |
+
"""
|
| 1406 |
|
|
|
|
| 1407 |
def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
|
| 1408 |
try:
|
| 1409 |
metrics = super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) or {}
|
|
|
|
| 1413 |
traceback.print_exc()
|
| 1414 |
try:
|
| 1415 |
cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
|
| 1416 |
+
cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
|
| 1417 |
metrics = {k: float(v) if isinstance(v, (float, int, np.floating, np.integer)) else v for k, v in cl_metrics.items()}
|
| 1418 |
metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
|
| 1419 |
except Exception as e2:
|
|
|
|
| 1421 |
traceback.print_exc()
|
| 1422 |
metrics = {"eval_loss": float("nan"), "epoch": float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else 0.0}
|
| 1423 |
return metrics
|
| 1424 |
+
|
| 1425 |
try:
|
| 1426 |
cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
|
| 1427 |
+
cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
|
| 1428 |
except Exception as e:
|
| 1429 |
print("Warning: evaluate_multimodal failed inside CLTrainer.evaluate():", e)
|
| 1430 |
cl_metrics = {}
|
| 1431 |
+
|
| 1432 |
for k, v in cl_metrics.items():
|
| 1433 |
try:
|
| 1434 |
metrics[k] = float(v)
|
| 1435 |
except Exception:
|
| 1436 |
metrics[k] = v
|
| 1437 |
+
|
| 1438 |
+
if "eval_loss" not in metrics and "eval_loss" in cl_metrics:
|
| 1439 |
try:
|
| 1440 |
+
metrics["eval_loss"] = float(cl_metrics["eval_loss"])
|
| 1441 |
except Exception:
|
| 1442 |
+
metrics["eval_loss"] = cl_metrics["eval_loss"]
|
| 1443 |
+
|
| 1444 |
if "epoch" not in metrics:
|
| 1445 |
metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
|
| 1446 |
+
|
| 1447 |
return metrics
|
| 1448 |
|
| 1449 |
def _save(self, output_dir: str):
|
| 1450 |
os.makedirs(output_dir, exist_ok=True)
|
| 1451 |
+
|
| 1452 |
try:
|
| 1453 |
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
| 1454 |
except Exception:
|
| 1455 |
pass
|
| 1456 |
+
|
| 1457 |
try:
|
| 1458 |
model_to_save = self.model.mm if hasattr(self.model, "mm") else self.model
|
| 1459 |
torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
|
|
|
|
| 1465 |
raise e
|
| 1466 |
except Exception as e2:
|
| 1467 |
print("Warning: failed to save model state_dict:", e2)
|
| 1468 |
+
|
| 1469 |
try:
|
| 1470 |
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
| 1471 |
except Exception:
|
| 1472 |
pass
|
| 1473 |
+
|
| 1474 |
try:
|
| 1475 |
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
| 1476 |
except Exception:
|
|
|
|
| 1480 |
best_ckpt = self.state.best_model_checkpoint
|
| 1481 |
if not best_ckpt:
|
| 1482 |
return
|
| 1483 |
+
|
| 1484 |
candidate = os.path.join(best_ckpt, "pytorch_model.bin")
|
| 1485 |
if not os.path.exists(candidate):
|
| 1486 |
candidate = os.path.join(best_ckpt, "model.bin")
|
| 1487 |
if not os.path.exists(candidate):
|
| 1488 |
candidate = None
|
| 1489 |
+
|
| 1490 |
if candidate is None:
|
| 1491 |
print(f"CLTrainer._load_best_model(): no compatible pytorch_model.bin found in {best_ckpt}; skipping load.")
|
| 1492 |
return
|
| 1493 |
+
|
| 1494 |
try:
|
| 1495 |
state_dict = torch.load(candidate, map_location=self.args.device)
|
| 1496 |
model_to_load = self.model.mm if hasattr(self.model, "mm") else self.model
|
|
|
|
| 1500 |
print("CLTrainer._load_best_model: failed to load state_dict using torch.load:", e)
|
| 1501 |
return
|
| 1502 |
|
|
|
|
| 1503 |
|
| 1504 |
+
# =============================================================================
|
| 1505 |
+
# Model construction + weight loading
|
| 1506 |
+
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1507 |
|
| 1508 |
+
def load_state_dict_if_present(model: nn.Module, ckpt_dir: str, filename: str = "pytorch_model.bin") -> None:
    """Best-effort weight restore: load `filename` from `ckpt_dir` if it exists.

    Missing files are a silent no-op; load failures are reported but never raised
    (strict=False tolerates partial key overlap).
    """
    path = os.path.join(ckpt_dir, filename)
    if not os.path.exists(path):
        return
    try:
        state = torch.load(path, map_location="cpu")
        model.load_state_dict(state, strict=False)
        print(f"Loaded weights from {path}")
    except Exception as e:
        print(f"Could not load weights from {path}: {e}")
|
| 1517 |
+
|
| 1518 |
+
|
| 1519 |
+
def build_models(device: torch.device) -> Tuple[MultimodalContrastiveModel, PSMILESDebertaEncoder]:
    """Instantiate unimodal encoders, optionally load best checkpoints, and assemble the multimodal model."""

    def _restore(encoder, ckpt_dir):
        # Best-effort weight restore followed by device placement.
        load_state_dict_if_present(encoder, ckpt_dir)
        encoder.to(device)
        return encoder

    # 2D graph branch (GINE).
    gine_encoder = _restore(
        GineEncoder(node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS, max_atomic_z=MAX_ATOMIC_Z),
        BEST_GINE_DIR,
    )

    # 3D geometry branch (SchNet).
    schnet_encoder = _restore(
        NodeSchNetWrapper(
            hidden_channels=SCHNET_HIDDEN,
            num_interactions=SCHNET_NUM_INTERACTIONS,
            num_gaussians=SCHNET_NUM_GAUSSIANS,
            cutoff=SCHNET_CUTOFF,
            max_num_neighbors=SCHNET_MAX_NEIGHBORS,
        ),
        BEST_SCHNET_DIR,
    )

    # Fingerprint branch (pooled transformer).
    fp_encoder = _restore(
        FingerprintEncoder(
            vocab_size=VOCAB_SIZE_FP,
            hidden_dim=256,
            seq_len=FP_LENGTH,
            num_layers=4,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
        ),
        BEST_FP_DIR,
    )

    # PSMILES branch (DeBERTa): prefer the saved directory, fall back to a fresh model.
    psmiles_encoder = None
    if os.path.isdir(BEST_PSMILES_DIR):
        try:
            psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
            print("Loaded Deberta (PSMILES) from", BEST_PSMILES_DIR)
        except Exception as e:
            print("Failed to load Deberta from saved directory:", e)

    if psmiles_encoder is None:
        psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=None)

    psmiles_encoder.to(device)

    # Assemble the joint contrastive model on the target device.
    multimodal_model = MultimodalContrastiveModel(gine_encoder, schnet_encoder, fp_encoder, psmiles_encoder, emb_dim=600)
    multimodal_model.to(device)

    return multimodal_model, psmiles_encoder
|
| 1568 |
+
|
| 1569 |
+
|
| 1570 |
+
# =============================================================================
|
| 1571 |
+
# Main execution
|
| 1572 |
+
# =============================================================================
|
| 1573 |
+
|
| 1574 |
+
def main():
    """End-to-end pretraining driver: prepare data, build encoders, run the HF
    Trainer with the contrastive collator, then save and evaluate the best model.

    NOTE(review): this function deliberately mutates module-level globals
    (train_loader, val_loader, multimodal_model, device, tokenizer) because the
    callback and CLTrainer reference them directly — confirm before refactoring.
    """
    # ---- setup ----
    ensure_dir(OUTPUT_DIR)
    ensure_dir(PREPROC_DIR)

    device_local = get_device()
    print("Device:", device_local)

    set_seed(42)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=25,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,  # effective train batch = 16 * 4
        eval_strategy="epoch",
        logging_steps=100,
        learning_rate=1e-4,
        weight_decay=0.01,
        eval_accumulation_steps=1000,
        fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
        save_strategy="epoch",
        save_steps=500,
        disable_tqdm=False,
        logging_first_step=True,
        report_to=[],  # no external experiment trackers
        dataloader_num_workers=0,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    # ---- data ----
    sample_files = prepare_or_load_data_streaming(
        csv_path=CSV_PATH,
        preproc_dir=PREPROC_DIR,
        target_rows=TARGET_ROWS,
        chunksize=CHUNKSIZE,
    )

    tokenizer_local = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)

    global train_loader, val_loader, multimodal_model, device, tokenizer  # kept for callback references (same behavior)
    tokenizer = tokenizer_local
    device = device_local

    train_loader, val_loader, train_subset, val_subset = build_dataloaders(
        sample_files=sample_files,
        tokenizer=tokenizer_local,
        train_batch_size=training_args.per_device_train_batch_size,
        eval_batch_size=training_args.per_device_eval_batch_size,
        seed=42,
    )

    # ---- models ----
    multimodal_model, _psmiles_encoder = build_models(device_local)

    hf_model = HFMultimodalModule(multimodal_model, tokenizer_local).to(device_local)
    data_collator = ContrastiveDataCollator(mask_prob=P_MASK)

    callback = VerboseTrainingCallback(patience=10)

    trainer = CLTrainer(
        model=hf_model,
        args=training_args,
        train_dataset=train_subset,
        eval_dataset=val_subset,
        data_collator=data_collator,
        callbacks=[callback],
    )
    callback.trainer_ref = trainer

    # Force HF Trainer to use our prebuilt PyTorch DataLoaders
    # (monkeypatching the accessors bypasses Trainer's own sampler/collator path).
    trainer.get_train_dataloader = lambda dataset=None: train_loader
    trainer.get_eval_dataloader = lambda eval_dataset=None: val_loader

    # Optimizer (kept as in original script)
    # NOTE(review): this optimizer is never passed to the Trainer, which builds
    # its own — presumably intentional/vestigial; confirm before removing.
    _optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)

    total_params = sum(p.numel() for p in multimodal_model.parameters())
    trainable_params = sum(p.numel() for p in multimodal_model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params

    print("\n MODEL PARAMETERS:")
    print(f" Total Parameters: {total_params:,}")
    print(f" Trainable Parameters: {trainable_params:,}")
    print(f" Non-trainable Parameters: {non_trainable_params:,}")

    # Clear GPU cache
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass

    # ---- train ----
    training_start_time = time.time()
    trainer.train()
    training_end_time = time.time()

    # ---- save best ----
    best_dir = os.path.join(OUTPUT_DIR, "best")
    os.makedirs(best_dir, exist_ok=True)

    try:
        best_ckpt = trainer.state.best_model_checkpoint
        if best_ckpt:
            # strict=False tolerates missing/extra keys in the checkpoint.
            multimodal_model.load_state_dict(torch.load(os.path.join(best_ckpt, "pytorch_model.bin"), map_location=device_local), strict=False)
            print(f"Loaded best checkpoint from {best_ckpt} into multimodal_model for final evaluation.")
            # NOTE(review): save nested under `if best_ckpt` — confirm whether the
            # final weights should also be exported when no best checkpoint exists.
            torch.save(multimodal_model.state_dict(), os.path.join(best_dir, "pytorch_model.bin"))
            print(f" Saved best multimodal model to {os.path.join(best_dir, 'pytorch_model.bin')}")
    except Exception as e:
        print("Warning: failed to load/save best model from Trainer:", e)

    # ---- final evaluation ----
    final_metrics = {}
    try:
        if trainer.state.best_model_checkpoint:
            trainer._load_best_model()
            final_metrics = trainer.evaluate(eval_dataset=val_subset)
        else:
            final_metrics = evaluate_multimodal(multimodal_model, val_loader, device_local, tokenizer_local, mask_target="fp")
    except Exception as e:
        print("Warning: final evaluation via trainer.evaluate failed, falling back to direct evaluate_multimodal:", e)
        final_metrics = evaluate_multimodal(multimodal_model, val_loader, device_local, tokenizer_local, mask_target="fp")

    print("\n" + "=" * 80)
    print(" FINAL TRAINING RESULTS")
    print("=" * 80)
    print(f"Total Training Time: {training_end_time - training_start_time:.2f}s")

    best_ckpt = trainer.state.best_model_checkpoint if hasattr(trainer.state, "best_model_checkpoint") else None
    print(f"Best Checkpoint: {best_ckpt if best_ckpt else '(none saved)'}")

    # Metric names match evaluate_multimodal's return dict; missing keys fall
    # back to NaN/0.0 so the report never raises.
    hf_eval_loss = final_metrics.get("eval_loss", float("nan"))
    hf_eval_acc = final_metrics.get("eval_accuracy", 0.0)
    hf_eval_f1 = final_metrics.get("eval_f1_weighted", 0.0)

    print(f"Val Loss (HF reported / trainer.evaluate): {hf_eval_loss:.4f}")
    print(f"Val Acc (CL evaluator): {hf_eval_acc:.4f}")
    print(f"Val F1 Weighted (CL evaluator): {hf_eval_f1:.4f}")
    print(f"Total Trainable Params: {trainable_params:,}")
    print(f"Total Non-trainable Params: {non_trainable_params:,}")
    print("=" * 80)


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|