""" PolyFusion - CL.py Multimodal contrastive pretraining script (DeBERTaV2 + GINE + SchNet + Transformer). """ import os import sys import csv import json import time import math import random import shutil from pathlib import Path from typing import List, Optional, Tuple, Dict # Increase csv field size limit safely try: csv.field_size_limit(sys.maxsize) except OverflowError: csv.field_size_limit(2**31 - 1) import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader # Shared model utilities from PolyFusion.GINE import GineEncoder, GineBlock, MaskedGINE, match_edge_attr_to_index, safe_get from PolyFusion.SchNet import NodeSchNetWrapper from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer # HF Trainer & Transformers from transformers import TrainingArguments, Trainer from transformers.trainer_callback import TrainerCallback from sklearn.model_selection import train_test_split from sklearn.metrics import f1_score # ============================================================================= # Configuration (paths are placeholders; update for your environment) # ============================================================================= P_MASK = 0.15 MAX_ATOMIC_Z = 85 MASK_ATOM_ID = MAX_ATOMIC_Z + 1 # GINE params NODE_EMB_DIM = 300 EDGE_EMB_DIM = 300 NUM_GNN_LAYERS = 5 # SchNet params SCHNET_NUM_GAUSSIANS = 50 SCHNET_NUM_INTERACTIONS = 6 SCHNET_CUTOFF = 10.0 SCHNET_MAX_NEIGHBORS = 64 SCHNET_HIDDEN = 600 # Fingerprint Transformer params FP_LENGTH = 2048 MASK_TOKEN_ID_FP = 2 VOCAB_SIZE_FP = 3 # DeBERTaV2 params DEBERTA_HIDDEN = 600 PSMILES_MAX_LEN = 128 # Contrastive params TEMPERATURE = 0.07 REC_LOSS_WEIGHT = 1.0 # Reconstruction loss weight # Data / preprocessing CSV_PATH = "/path/to/polymer_structures_unified_processed.csv" TARGET_ROWS = 2000000 CHUNKSIZE = 50000 
PREPROC_DIR = "/path/to/preprocessed_samples" # Tokenizer assets SPM_MODEL = "/path/to/spm.model" # Outputs / checkpoints OUTPUT_DIR = "/path/to/multimodal_output" BEST_GINE_DIR = "/path/to/gin_output/best" BEST_SCHNET_DIR = "/path/to/schnet_output/best" BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best" BEST_PSMILES_DIR = "/path/to/polybert_output/best" # ============================================================================= # Reproducibility + device # ============================================================================= def get_device() -> torch.device: """Select CUDA if available (respects CUDA_VISIBLE_DEVICES), else CPU.""" return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") def set_seed(seed: int = 42) -> None: """Set Python/Numpy/Torch seeds for deterministic-ish behavior.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) # ============================================================================= # Preprocessing (streaming to disk to avoid large memory spikes) # ============================================================================= def ensure_dir(path: str) -> None: """Create a directory if it doesn't exist.""" os.makedirs(path, exist_ok=True) def prepare_or_load_data_streaming( csv_path: str, preproc_dir: str, target_rows: int = TARGET_ROWS, chunksize: int = CHUNKSIZE ) -> List[str]: """ Prepare per-sample serialized files (torch .pt) for lazy loading. - If `preproc_dir` already contains `sample_*.pt`, reuse them. - Else stream the CSV in chunks and write `sample_{idx:08d}.pt` files. """ ensure_dir(preproc_dir) existing = sorted([p for p in Path(preproc_dir).glob("sample_*.pt")]) if len(existing) > 0: print(f"Found {len(existing)} preprocessed sample files in {preproc_dir}; reusing those (no reparse).") return [str(p) for p in existing] print("No existing per-sample preprocessed folder found. 
Parsing CSV chunked and writing per-sample files (streaming).") rows_written = 0 sample_idx = 0 for chunk in pd.read_csv(csv_path, engine="python", chunksize=chunksize): has_graph = "graph" in chunk.columns has_geometry = "geometry" in chunk.columns has_fp = "fingerprints" in chunk.columns has_psmiles = "psmiles" in chunk.columns for i_row in range(len(chunk)): if rows_written >= target_rows: break row = chunk.iloc[i_row] # Per-row modality payloads (None if missing) gine_sample = None schnet_sample = None fp_sample = None psmiles_raw = None # -------- Graph / GINE modality -------- if has_graph: val = row.get("graph", "") try: graph_field = ( json.loads(val) if isinstance(val, str) and val.strip() != "" else (val if not isinstance(val, str) else None) ) except Exception: graph_field = None if graph_field: node_features = safe_get(graph_field, "node_features", None) if node_features: atomic_nums = [] chirality_vals = [] formal_charges = [] for nf in node_features: an = safe_get(nf, "atomic_num", None) if an is None: an = safe_get(nf, "atomic_number", 0) ch = safe_get(nf, "chirality", 0) fc = safe_get(nf, "formal_charge", 0) try: atomic_nums.append(int(an)) except Exception: atomic_nums.append(0) chirality_vals.append(float(ch)) formal_charges.append(float(fc)) edge_indices_raw = safe_get(graph_field, "edge_indices", None) edge_features_raw = safe_get(graph_field, "edge_features", None) edge_index = None edge_attr = None # Handle missing edge_indices via adjacency_matrix if edge_indices_raw is None: adj_mat = safe_get(graph_field, "adjacency_matrix", None) if adj_mat: srcs, dsts = [], [] for i_r, row_adj in enumerate(adj_mat): for j, val2 in enumerate(row_adj): if val2: srcs.append(i_r) dsts.append(j) if len(srcs) > 0: edge_index = [srcs, dsts] E = len(srcs) edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)] else: # edge_indices_raw can be: # - list of pairs [[u,v], ...] 
# - two lists [[srcs], [dsts]] srcs, dsts = [], [] if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list): first = edge_indices_raw[0] if len(first) == 2 and isinstance(first[0], int): # list of pairs try: srcs = [int(p[0]) for p in edge_indices_raw] dsts = [int(p[1]) for p in edge_indices_raw] except Exception: srcs, dsts = [], [] else: # two lists try: srcs = [int(x) for x in edge_indices_raw[0]] dsts = [int(x) for x in edge_indices_raw[1]] except Exception: srcs, dsts = [], [] if len(srcs) == 0 and isinstance(edge_indices_raw, list) and all( isinstance(p, (list, tuple)) and len(p) == 2 for p in edge_indices_raw ): srcs = [int(p[0]) for p in edge_indices_raw] dsts = [int(p[1]) for p in edge_indices_raw] if len(srcs) > 0: edge_index = [srcs, dsts] if edge_features_raw and isinstance(edge_features_raw, list): bond_types, stereos, is_conjs = [], [], [] for ef in edge_features_raw: bt = safe_get(ef, "bond_type", 0) st = safe_get(ef, "stereo", 0) ic = safe_get(ef, "is_conjugated", False) bond_types.append(float(bt)) stereos.append(float(st)) is_conjs.append(float(1.0 if ic else 0.0)) edge_attr = list(zip(bond_types, stereos, is_conjs)) else: E = len(srcs) edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)] if edge_index is not None: gine_sample = { "node_atomic": atomic_nums, "node_chirality": chirality_vals, "node_charge": formal_charges, "edge_index": edge_index, "edge_attr": edge_attr, } # -------- Geometry / SchNet modality -------- if has_geometry and schnet_sample is None: val = row.get("geometry", "") try: geom = ( json.loads(val) if isinstance(val, str) and val.strip() != "" else (val if not isinstance(val, str) else None) ) conf = geom.get("best_conformer") if isinstance(geom, dict) else None if conf: atomic = conf.get("atomic_numbers", []) coords = conf.get("coordinates", []) if len(atomic) == len(coords) and len(atomic) > 0: schnet_sample = {"atomic": atomic, "coords": coords} except Exception: 
schnet_sample = None # -------- Fingerprint modality -------- if has_fp: fpval = row.get("fingerprints", "") if fpval is None or (isinstance(fpval, str) and fpval.strip() == ""): fp_sample = [0] * FP_LENGTH else: fp_json = None try: fp_json = json.loads(fpval) if isinstance(fpval, str) else fpval except Exception: try: fp_json = json.loads(str(fpval).replace("'", '"')) except Exception: parts = [p.strip().strip('"').strip("'") for p in str(fpval).split(",")] bits = [1 if p in ("1", "True", "true") else 0 for p in parts[:FP_LENGTH]] if len(bits) < FP_LENGTH: bits += [0] * (FP_LENGTH - len(bits)) fp_sample = bits if fp_sample is None: bits = ( safe_get(fp_json, "morgan_r3_bits", None) if isinstance(fp_json, dict) else (fp_json if isinstance(fp_json, list) else None) ) if bits is None: fp_sample = [0] * FP_LENGTH else: normalized = [] for b in bits: if isinstance(b, str): b_clean = b.strip().strip('"').strip("'") normalized.append(1 if b_clean in ("1", "True", "true") else 0) elif isinstance(b, (int, np.integer)): normalized.append(1 if int(b) != 0 else 0) else: normalized.append(0) if len(normalized) >= FP_LENGTH: break if len(normalized) < FP_LENGTH: normalized.extend([0] * (FP_LENGTH - len(normalized))) fp_sample = normalized[:FP_LENGTH] # -------- PSMILES modality -------- if has_psmiles: s = row.get("psmiles", "") psmiles_raw = "" if s is None else str(s) # Require at least 2 modalities modalities_present = sum( [1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]] ) if modalities_present >= 2: sample = { "gine": gine_sample, "schnet": schnet_sample, "fp": fp_sample, "psmiles_raw": psmiles_raw, } sample_path = os.path.join(preproc_dir, f"sample_{sample_idx:08d}.pt") try: torch.save(sample, sample_path) except Exception as save_e: print("Warning: failed to torch.save sample:", save_e) # fallback JSON for debugging try: with open(sample_path + ".json", "w") as fjson: json.dump(sample, fjson) except Exception: pass sample_idx += 1 
rows_written += 1 if rows_written >= target_rows: break print(f"Wrote {sample_idx} sample files to {preproc_dir}.") return [str(p) for p in sorted(Path(preproc_dir).glob("sample_*.pt"))] # ============================================================================= # Dataset + collate # ============================================================================= class LazyMultimodalDataset(Dataset): """ Lazily loads per-sample files from disk and converts them into tensors. Each sample file is expected to contain: - gine: dict or None - schnet: dict or None - fp: list[int] or tensor - psmiles_raw: str """ def __init__(self, sample_file_list: List[str], tokenizer, fp_length: int = FP_LENGTH, psmiles_max_len: int = PSMILES_MAX_LEN): self.files = sample_file_list self.tokenizer = tokenizer self.fp_length = fp_length self.psmiles_max_len = psmiles_max_len def __len__(self) -> int: return len(self.files) def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]: sample_path = self.files[idx] # prefer torch.load if .pt, else try json if sample_path.endswith(".pt"): sample = torch.load(sample_path, map_location="cpu") else: with open(sample_path, "r") as f: sample = json.load(f) # ---- GINE tensors ---- gine_raw = sample.get("gine", None) if gine_raw: node_atomic = torch.tensor(gine_raw.get("node_atomic", []), dtype=torch.long) node_chirality = torch.tensor(gine_raw.get("node_chirality", []), dtype=torch.float) node_charge = torch.tensor(gine_raw.get("node_charge", []), dtype=torch.float) if gine_raw.get("edge_index", None) is not None: edge_index = torch.tensor(gine_raw["edge_index"], dtype=torch.long) else: edge_index = torch.tensor([[], []], dtype=torch.long) ea_raw = gine_raw.get("edge_attr", None) if ea_raw: edge_attr = torch.tensor(ea_raw, dtype=torch.float) else: edge_attr = torch.zeros((edge_index.size(1), 3), dtype=torch.float) gine_item = { "z": node_atomic, "chirality": node_chirality, "formal_charge": node_charge, "edge_index": edge_index, 
"edge_attr": edge_attr, } else: gine_item = { "z": torch.tensor([], dtype=torch.long), "chirality": torch.tensor([], dtype=torch.float), "formal_charge": torch.tensor([], dtype=torch.float), "edge_index": torch.tensor([[], []], dtype=torch.long), "edge_attr": torch.zeros((0, 3), dtype=torch.float), } # ---- SchNet tensors ---- schnet_raw = sample.get("schnet", None) if schnet_raw: s_z = torch.tensor(schnet_raw.get("atomic", []), dtype=torch.long) s_pos = torch.tensor(schnet_raw.get("coords", []), dtype=torch.float) schnet_item = {"z": s_z, "pos": s_pos} else: schnet_item = {"z": torch.tensor([], dtype=torch.long), "pos": torch.tensor([], dtype=torch.float)} # ---- Fingerprint tensors ---- fp_raw = sample.get("fp", None) if fp_raw is None: fp_vec = torch.zeros((self.fp_length,), dtype=torch.long) else: if isinstance(fp_raw, (list, tuple)): arr = list(fp_raw)[:self.fp_length] if len(arr) < self.fp_length: arr = arr + [0] * (self.fp_length - len(arr)) fp_vec = torch.tensor(arr, dtype=torch.long) elif isinstance(fp_raw, torch.Tensor): fp_vec = fp_raw.clone().to(torch.long) else: fp_vec = torch.zeros((self.fp_length,), dtype=torch.long) # ---- PSMILES tensors ---- psm_raw = sample.get("psmiles_raw", "") or "" enc = self.tokenizer(psm_raw, truncation=True, padding="max_length", max_length=self.psmiles_max_len) p_input_ids = torch.tensor(enc["input_ids"], dtype=torch.long) p_attn = torch.tensor(enc["attention_mask"], dtype=torch.bool) return { "gine": { "z": gine_item["z"], "chirality": gine_item["chirality"], "formal_charge": gine_item["formal_charge"], "edge_index": gine_item["edge_index"], "edge_attr": gine_item["edge_attr"], "num_nodes": int(gine_item["z"].size(0)) if gine_item["z"].numel() > 0 else 0, }, "schnet": {"z": schnet_item["z"], "pos": schnet_item["pos"]}, "fp": {"input_ids": fp_vec}, "psmiles": {"input_ids": p_input_ids, "attention_mask": p_attn}, } def multimodal_collate(batch_list: List[Dict[str, Dict[str, torch.Tensor]]]) -> Dict[str, Dict[str, 
torch.Tensor]]: """ Collate a list of LazyMultimodalDataset samples into a single multimodal batch. Output keys: - gine: {z, chirality, formal_charge, edge_index, edge_attr, batch} - schnet: {z, pos, batch} - fp: {input_ids, attention_mask} - psmiles: {input_ids, attention_mask} """ # ---- GINE batching ---- all_z, all_ch, all_fc = [], [], [] all_edge_index, all_edge_attr = [], [] batch_mapping = [] node_offset = 0 for i, item in enumerate(batch_list): g = item["gine"] z = g["z"] n = z.size(0) all_z.append(z) all_ch.append(g["chirality"]) all_fc.append(g["formal_charge"]) batch_mapping.append(torch.full((n,), i, dtype=torch.long)) if g["edge_index"] is not None and g["edge_index"].numel() > 0: ei_offset = g["edge_index"] + node_offset all_edge_index.append(ei_offset) ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3) all_edge_attr.append(ea) node_offset += n if len(all_z) == 0: z_batch = torch.tensor([], dtype=torch.long) ch_batch = torch.tensor([], dtype=torch.float) fc_batch = torch.tensor([], dtype=torch.float) batch_batch = torch.tensor([], dtype=torch.long) edge_index_batched = torch.empty((2, 0), dtype=torch.long) edge_attr_batched = torch.zeros((0, 3), dtype=torch.float) else: z_batch = torch.cat(all_z, dim=0) ch_batch = torch.cat(all_ch, dim=0) fc_batch = torch.cat(all_fc, dim=0) batch_batch = torch.cat(batch_mapping, dim=0) if len(all_edge_index) > 0: edge_index_batched = torch.cat(all_edge_index, dim=1) edge_attr_batched = torch.cat(all_edge_attr, dim=0) else: edge_index_batched = torch.empty((2, 0), dtype=torch.long) edge_attr_batched = torch.zeros((0, 3), dtype=torch.float) # ---- SchNet batching ---- all_sz, all_pos, schnet_batch = [], [], [] for i, item in enumerate(batch_list): s = item["schnet"] s_z = s["z"] s_pos = s["pos"] if s_z.numel() == 0: continue all_sz.append(s_z) all_pos.append(s_pos) schnet_batch.append(torch.full((s_z.size(0),), i, dtype=torch.long)) if len(all_sz) == 0: s_z_batch = torch.tensor([], 
dtype=torch.long) s_pos_batch = torch.tensor([], dtype=torch.float) s_batch_batch = torch.tensor([], dtype=torch.long) else: s_z_batch = torch.cat(all_sz, dim=0) s_pos_batch = torch.cat(all_pos, dim=0) s_batch_batch = torch.cat(schnet_batch, dim=0) # ---- FP batching ---- fp_ids = torch.stack( [ item["fp"]["input_ids"] if isinstance(item["fp"]["input_ids"], torch.Tensor) else torch.tensor(item["fp"]["input_ids"], dtype=torch.long) for item in batch_list ], dim=0 ) fp_attn = torch.ones_like(fp_ids, dtype=torch.bool) # ---- PSMILES batching ---- p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch_list], dim=0) p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch_list], dim=0) return { "gine": { "z": z_batch, "chirality": ch_batch, "formal_charge": fc_batch, "edge_index": edge_index_batched, "edge_attr": edge_attr_batched, "batch": batch_batch, }, "schnet": {"z": s_z_batch, "pos": s_pos_batch, "batch": s_batch_batch}, "fp": {"input_ids": fp_ids, "attention_mask": fp_attn}, "psmiles": {"input_ids": p_ids, "attention_mask": p_attn}, } def build_dataloaders( sample_files: List[str], tokenizer, train_batch_size: int, eval_batch_size: int, seed: int = 42 ) -> Tuple[DataLoader, DataLoader, torch.utils.data.Subset, torch.utils.data.Subset]: """ Create train/val subsets and corresponding DataLoaders. 
""" dataset = LazyMultimodalDataset(sample_files, tokenizer, fp_length=FP_LENGTH, psmiles_max_len=PSMILES_MAX_LEN) train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=seed) train_subset = torch.utils.data.Subset(dataset, train_idx) val_subset = torch.utils.data.Subset(dataset, val_idx) train_loader = DataLoader( train_subset, batch_size=train_batch_size, shuffle=True, collate_fn=multimodal_collate, num_workers=0, drop_last=False, ) val_loader = DataLoader( val_subset, batch_size=eval_batch_size, shuffle=False, collate_fn=multimodal_collate, num_workers=0, drop_last=False, ) return train_loader, val_loader, train_subset, val_subset # ============================================================================= # Multimodal contrastive model # ============================================================================= class MultimodalContrastiveModel(nn.Module): """ Wraps unimodal encoders and computes: - InfoNCE between masked modality embedding vs mean anchor of other modalities - Optional reconstruction losses for masked tokens/atoms when labels are present """ def __init__( self, gine_encoder: Optional[GineEncoder], schnet_encoder: Optional[NodeSchNetWrapper], fp_encoder: Optional[FingerprintEncoder], psmiles_encoder: Optional[PSMILESDebertaEncoder], emb_dim: int = 600, ): super().__init__() self.gine = gine_encoder self.schnet = schnet_encoder self.fp = fp_encoder self.psmiles = psmiles_encoder self.proj_gine = nn.Linear(getattr(self.gine, "pool_proj").out_features if self.gine is not None else emb_dim, emb_dim) if self.gine is not None else None self.proj_schnet = nn.Linear(getattr(self.schnet, "pool_proj").out_features if self.schnet is not None else emb_dim, emb_dim) if self.schnet is not None else None self.proj_fp = nn.Linear(getattr(self.fp, "pool_proj").out_features if self.fp is not None else emb_dim, emb_dim) if self.fp is not None else None self.proj_psmiles = nn.Linear(getattr(self.psmiles, 
"pool_proj").out_features if self.psmiles is not None else emb_dim, emb_dim) if self.psmiles is not None else None self.temperature = TEMPERATURE self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean") def encode(self, batch_mods: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Compute normalized projected embeddings for available modalities.""" embs = {} if "gine" in batch_mods and self.gine is not None: g = batch_mods["gine"] emb_g = self.gine(g["z"], g["chirality"], g["formal_charge"], g["edge_index"], g["edge_attr"], g.get("batch", None)) embs["gine"] = F.normalize(self.proj_gine(emb_g), dim=-1) if "schnet" in batch_mods and self.schnet is not None: s = batch_mods["schnet"] emb_s = self.schnet(s["z"], s["pos"], s.get("batch", None)) embs["schnet"] = F.normalize(self.proj_schnet(emb_s), dim=-1) if "fp" in batch_mods and self.fp is not None: f = batch_mods["fp"] emb_f = self.fp(f["input_ids"], f.get("attention_mask", None)) embs["fp"] = F.normalize(self.proj_fp(emb_f), dim=-1) if "psmiles" in batch_mods and self.psmiles is not None: p = batch_mods["psmiles"] emb_p = self.psmiles(p["input_ids"], p.get("attention_mask", None)) embs["psmiles"] = F.normalize(self.proj_psmiles(emb_p), dim=-1) return embs def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str): """ Compute total loss = InfoNCE + REC_LOSS_WEIGHT * reconstruction_loss """ device = next(self.parameters()).device embs = self.encode(batch_mods) info = {} if mask_target not in embs: return torch.tensor(0.0, device=device), {"batch_size": 0} target = embs[mask_target] other_keys = [k for k in embs.keys() if k != mask_target] if len(other_keys) == 0: return torch.tensor(0.0, device=device), {"batch_size": target.size(0)} anchor = torch.stack([embs[k] for k in other_keys], dim=0).mean(dim=0) logits = torch.matmul(anchor, target.T) / self.temperature B = logits.size(0) labels = torch.arange(B, device=logits.device) info_nce_loss = F.cross_entropy(logits, labels) 
info["info_nce_loss"] = float(info_nce_loss.detach().cpu().item()) # Optional reconstruction terms rec_losses = [] rec_details = {} # GINE node reconstruction (atomic ids) if labels present try: if "gine" in batch_mods and self.gine is not None: gm = batch_mods["gine"] labels_nodes = gm.get("labels", None) if labels_nodes is not None: node_logits = self.gine.node_logits(gm["z"], gm["chirality"], gm["formal_charge"], gm["edge_index"], gm["edge_attr"]) if labels_nodes.dim() == 1 and node_logits.size(0) == labels_nodes.size(0): loss_gine = self.ce_loss(node_logits, labels_nodes.to(node_logits.device)) rec_losses.append(loss_gine) rec_details["gine_rec_loss"] = float(loss_gine.detach().cpu().item()) except Exception as e: print("Warning: GINE reconstruction loss computation failed:", e) # SchNet node reconstruction if labels present try: if "schnet" in batch_mods and self.schnet is not None: sm = batch_mods["schnet"] labels_nodes = sm.get("labels", None) if labels_nodes is not None: node_logits = self.schnet.node_logits(sm["z"], sm["pos"], sm.get("batch", None)) if labels_nodes.dim() == 1 and node_logits.size(0) == labels_nodes.size(0): loss_schnet = self.ce_loss(node_logits, labels_nodes.to(node_logits.device)) rec_losses.append(loss_schnet) rec_details["schnet_rec_loss"] = float(loss_schnet.detach().cpu().item()) except Exception as e: print("Warning: SchNet reconstruction loss computation failed:", e) # FP token reconstruction if labels present try: if "fp" in batch_mods and self.fp is not None: fm = batch_mods["fp"] labels_fp = fm.get("labels", None) if labels_fp is not None: token_logits = self.fp.token_logits(fm["input_ids"], fm.get("attention_mask", None)) Bf, Lf, V = token_logits.shape logits2 = token_logits.view(-1, V) labels2 = labels_fp.view(-1).to(logits2.device) loss_fp = self.ce_loss(logits2, labels2) rec_losses.append(loss_fp) rec_details["fp_rec_loss"] = float(loss_fp.detach().cpu().item()) except Exception as e: print("Warning: FP reconstruction loss 
computation failed:", e) # PSMILES MLM loss if labels present try: if "psmiles" in batch_mods and self.psmiles is not None: pm = batch_mods["psmiles"] labels_ps = pm.get("labels", None) if labels_ps is not None: loss_ps = self.psmiles.token_logits(pm["input_ids"], pm.get("attention_mask", None), labels=labels_ps) if isinstance(loss_ps, torch.Tensor): rec_losses.append(loss_ps) rec_details["psmiles_mlm_loss"] = float(loss_ps.detach().cpu().item()) except Exception as e: print("Warning: PSMILES MLM loss computation failed:", e) if len(rec_losses) > 0: rec_loss_total = sum(rec_losses) / len(rec_losses) info["reconstruction_loss"] = float(rec_loss_total.detach().cpu().item()) total_loss = info_nce_loss + REC_LOSS_WEIGHT * rec_loss_total info["total_loss"] = float(total_loss.detach().cpu().item()) info.update(rec_details) else: total_loss = info_nce_loss info["reconstruction_loss"] = 0.0 info["total_loss"] = float(total_loss.detach().cpu().item()) return total_loss, info # ============================================================================= # Masking utilities # ============================================================================= def mask_batch_for_modality(batch: dict, modality: str, tokenizer, p_mask: float = P_MASK) -> dict: """ Apply BERT-style masking to the selected modality and attach `labels`. 
""" b = {} # ---------------- GINE ---------------- if "gine" in batch: z = batch["gine"]["z"].clone() chir = batch["gine"]["chirality"].clone() fc = batch["gine"]["formal_charge"].clone() edge_index = batch["gine"]["edge_index"] edge_attr = batch["gine"]["edge_attr"] batch_map = batch["gine"].get("batch", None) n_nodes = z.size(0) dev = z.device is_selected = torch.rand(n_nodes, device=dev) < p_mask if is_selected.numel() > 0 and is_selected.all(): is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False labels_z = torch.full_like(z, fill_value=-100) if is_selected.any(): sel_idx = torch.nonzero(is_selected).squeeze(-1) if sel_idx.dim() == 0: sel_idx = sel_idx.unsqueeze(0) labels_z[is_selected] = z[is_selected] rand_atomic = torch.randint(1, MAX_ATOMIC_Z + 1, (sel_idx.size(0),), dtype=torch.long, device=dev) probs = torch.rand(sel_idx.size(0), device=dev) mask_choice = probs < 0.8 rand_choice = (probs >= 0.8) & (probs < 0.9) if mask_choice.any(): z[sel_idx[mask_choice]] = MASK_ATOM_ID if rand_choice.any(): z[sel_idx[rand_choice]] = rand_atomic[rand_choice] b["gine"] = { "z": z, "chirality": chir, "formal_charge": fc, "edge_index": edge_index, "edge_attr": edge_attr, "batch": batch_map, "labels": labels_z, } # ---------------- SchNet ---------------- if "schnet" in batch: z = batch["schnet"]["z"].clone() pos = batch["schnet"]["pos"].clone() batch_map = batch["schnet"].get("batch", None) n_nodes = z.size(0) dev = z.device is_selected = torch.rand(n_nodes, device=dev) < p_mask if is_selected.numel() > 0 and is_selected.all(): is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False labels_z = torch.full((n_nodes,), -100, dtype=torch.long, device=dev) if is_selected.any(): sel_idx = torch.nonzero(is_selected).squeeze(-1) if sel_idx.dim() == 0: sel_idx = sel_idx.unsqueeze(0) labels_z[is_selected] = z[is_selected] probs_c = torch.rand(sel_idx.size(0), device=dev) noisy_choice = probs_c < 0.8 randpos_choice = (probs_c >= 0.8) & (probs_c < 0.9) if 
noisy_choice.any(): idx = sel_idx[noisy_choice] noise = torch.randn((idx.size(0), 3), device=pos.device) * 0.5 pos[idx] = pos[idx] + noise if randpos_choice.any(): idx = sel_idx[randpos_choice] mins = pos.min(dim=0).values maxs = pos.max(dim=0).values randpos = (torch.rand((idx.size(0), 3), device=pos.device) * (maxs - mins)) + mins pos[idx] = randpos b["schnet"] = {"z": z, "pos": pos, "batch": batch_map, "labels": labels_z} # ---------------- FP ---------------- if "fp" in batch: input_ids = batch["fp"]["input_ids"].clone() attn = batch["fp"].get("attention_mask", torch.ones_like(input_ids, dtype=torch.bool)) B, L = input_ids.shape dev = input_ids.device labels_z = torch.full_like(input_ids, -100) for i in range(B): sel = torch.rand(L, device=dev) < p_mask if sel.numel() > 0 and sel.all(): sel[torch.randint(0, L, (1,), device=dev)] = False sel_idx = torch.nonzero(sel).squeeze(-1) if sel_idx.numel() > 0: if sel_idx.dim() == 0: sel_idx = sel_idx.unsqueeze(0) labels_z[i, sel_idx] = input_ids[i, sel_idx] probs = torch.rand(sel_idx.size(0), device=dev) mask_choice = probs < 0.8 rand_choice = (probs >= 0.8) & (probs < 0.9) if mask_choice.any(): input_ids[i, sel_idx[mask_choice]] = MASK_TOKEN_ID_FP if rand_choice.any(): rand_bits = torch.randint(0, 2, (rand_choice.sum().item(),), dtype=torch.long, device=dev) input_ids[i, sel_idx[rand_choice]] = rand_bits b["fp"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z} # ---------------- PSMILES ---------------- if "psmiles" in batch: input_ids = batch["psmiles"]["input_ids"].clone() attn = batch["psmiles"]["attention_mask"].clone() B, L = input_ids.shape dev = input_ids.device labels_z = torch.full_like(input_ids, -100) # If tokenizer is unavailable, keep labels=-100 (no MLM loss) if tokenizer is None: b["psmiles"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z} else: mask_token_id = tokenizer.mask_token_id if getattr(tokenizer, "mask_token_id", None) is not None else 
getattr(tokenizer, "vocab", {}).get("", 1) for i in range(B): sel = torch.rand(L, device=dev) < p_mask if sel.numel() > 0 and sel.all(): sel[torch.randint(0, L, (1,), device=dev)] = False sel_idx = torch.nonzero(sel).squeeze(-1) if sel_idx.numel() > 0: if sel_idx.dim() == 0: sel_idx = sel_idx.unsqueeze(0) labels_z[i, sel_idx] = input_ids[i, sel_idx] probs = torch.rand(sel_idx.size(0), device=dev) mask_choice = probs < 0.8 rand_choice = (probs >= 0.8) & (probs < 0.9) if mask_choice.any(): input_ids[i, sel_idx[mask_choice]] = mask_token_id if rand_choice.any(): rand_ids = torch.randint(0, getattr(tokenizer, "vocab_size", 300), (rand_choice.sum().item(),), dtype=torch.long, device=dev) input_ids[i, sel_idx[rand_choice]] = rand_ids b["psmiles"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z} return b def mm_batch_to_model_input(masked_batch: dict) -> dict: """ Normalize the masked batch dict into the exact structure expected by MultimodalContrastiveModel. """ mm = {} if "gine" in masked_batch: gm = masked_batch["gine"] mm["gine"] = { "z": gm["z"], "chirality": gm["chirality"], "formal_charge": gm["formal_charge"], "edge_index": gm["edge_index"], "edge_attr": gm["edge_attr"], "batch": gm.get("batch", None), "labels": gm.get("labels", None), } if "schnet" in masked_batch: sm = masked_batch["schnet"] mm["schnet"] = {"z": sm["z"], "pos": sm["pos"], "batch": sm.get("batch", None), "labels": sm.get("labels", None)} if "fp" in masked_batch: fm = masked_batch["fp"] mm["fp"] = {"input_ids": fm["input_ids"], "attention_mask": fm.get("attention_mask", None), "labels": fm.get("labels", None)} if "psmiles" in masked_batch: pm = masked_batch["psmiles"] mm["psmiles"] = {"input_ids": pm["input_ids"], "attention_mask": pm.get("attention_mask", None), "labels": pm.get("labels", None)} return mm # ============================================================================= # Evaluation # ============================================================================= 
def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader: DataLoader, device: torch.device, tokenizer, mask_target: str = "fp") -> Dict[str, float]:
    """
    Contrastive-only evaluation loop.

    For every validation batch the `mask_target` modality is masked, all
    modality embeddings are computed, and an InfoNCE matching problem is
    scored: logits = mean(other-modality embeddings) @ target.T / temperature,
    with the diagonal as ground truth.

    Returns example-weighted averages:
    eval_loss, eval_accuracy (top-1), eval_f1_weighted.
    """
    model.eval()
    loss_total = 0.0
    example_total = 0
    acc_total = 0.0
    f1_total = 0.0

    with torch.no_grad():
        for raw_batch in val_loader:
            masked = mask_batch_for_modality(raw_batch, mask_target, tokenizer=tokenizer, p_mask=P_MASK)

            # Move every tensor leaf onto the evaluation device in place.
            for modality_dict in masked.values():
                for name, value in modality_dict.items():
                    if isinstance(value, torch.Tensor):
                        modality_dict[name] = value.to(device)

            embeddings = model.encode(mm_batch_to_model_input(masked))
            if mask_target not in embeddings:
                continue
            anchors = [v for k, v in embeddings.items() if k != mask_target]
            if not anchors:
                continue

            target = embeddings[mask_target]
            anchor = torch.stack(anchors, dim=0).mean(dim=0)
            logits = torch.matmul(anchor, target.T) / model.temperature
            n = logits.size(0)
            truth = torch.arange(n, device=logits.device)

            loss_total += F.cross_entropy(logits, truth).item() * n
            example_total += n

            predictions = logits.argmax(dim=1)
            batch_acc = (predictions == truth).float().mean().item()
            acc_total += batch_acc * n

            # Weighted F1 over instance IDs; fall back to accuracy when the
            # batch is degenerate (a single class) or scoring fails.
            try:
                truth_np = truth.cpu().numpy()
                pred_np = predictions.cpu().numpy()
                if len(np.unique(truth_np)) < 2:
                    batch_f1 = float(batch_acc)
                else:
                    batch_f1 = f1_score(truth_np, pred_np, average="weighted")
            except Exception:
                batch_f1 = float(batch_acc)
            f1_total += batch_f1 * n

    if example_total == 0:
        return {"eval_loss": float("nan"), "eval_accuracy": 0.0, "eval_f1_weighted": 0.0}
    return {
        "eval_loss": loss_total / example_total,
        "eval_accuracy": acc_total / example_total,
        "eval_f1_weighted": f1_total / example_total,
    }


# =============================================================================
# HF wrapper + collator + trainer
#
# =============================================================================


class HFMultimodalModule(nn.Module):
    """
    HuggingFace Trainer-facing wrapper:
    - Receives a full multimodal batch
    - Randomly masks one modality (provided by collator) inside forward
    - Returns a dict compatible with Trainer (loss, logits, labels)
    """

    def __init__(self, mm_model: MultimodalContrastiveModel, tokenizer):
        super().__init__()
        # Underlying contrastive model; registered submodule so parameters
        # are visible to the Trainer/optimizer.
        self.mm = mm_model
        # Tokenizer used by mask_batch_for_modality for PSMILES masking.
        self._tokenizer = tokenizer

    def forward(self, **kwargs):
        """
        Run one contrastive step.

        Accepts either a pre-packed {"batch": ..., "mask_target": ...} dict
        (as produced by ContrastiveDataCollator) or the modality sub-dicts
        spread directly into kwargs. Returns a dict with at least "loss" and
        "eval_loss"; "logits"/"labels" are attached best-effort for HF metric
        computation, and "mm_info" carries auxiliary info from the model.
        """
        if "batch" in kwargs:
            batch = kwargs["batch"]
            mask_target = kwargs.get("mask_target", "fp")
        else:
            # The Trainer may have unpacked the collated dict into kwargs.
            modality_keys = ["gine", "schnet", "fp", "psmiles"]
            found = {k: v for k, v in kwargs.items() if k in modality_keys}
            if len(found) > 0:
                batch = {k: found[k] for k in found}
                mask_target = kwargs.get("mask_target", "fp")
            else:
                raise ValueError(
                    "HFMultimodalModule.forward could not find 'batch' nor modality keys in inputs. "
                    f"Inputs keys: {list(kwargs.keys())}"
                )
        masked_batch = mask_batch_for_modality(batch, mask_target, tokenizer=self._tokenizer, p_mask=P_MASK)
        device = next(self.parameters()).device
        # Move every tensor in the (nested) masked batch onto the model device.
        for k in masked_batch:
            for subk in list(masked_batch[k].keys()):
                val = masked_batch[k][subk]
                if isinstance(val, torch.Tensor):
                    masked_batch[k][subk] = val.to(device)
        mm_in = mm_batch_to_model_input(masked_batch)
        loss, info = self.mm(mm_in, mask_target)
        logits = None
        labels = None
        # Best-effort: recompute InfoNCE logits (anchor @ target.T / T)
        # without gradients so HF can log classification-style metrics.
        # Failures here must never break training, hence the broad except.
        try:
            with torch.no_grad():
                embs = self.mm.encode(mm_in)
                if mask_target in embs:
                    target = embs[mask_target]
                    other_keys = [k for k in embs.keys() if k != mask_target]
                    if len(other_keys) > 0:
                        anchor = torch.stack([embs[k] for k in other_keys], dim=0).mean(dim=0)
                        logits = torch.matmul(anchor, target.T) / self.mm.temperature
                        B = logits.size(0)
                        labels = torch.arange(B, device=logits.device)
        except Exception as e:
            print("Warning: failed to compute logits/labels inside HFMultimodalModule.forward:", e)
            logits = None
            labels = None
        # Detached copy so Trainer's eval accumulation never holds the graph.
        eval_loss = loss.detach() if isinstance(loss, torch.Tensor) else torch.tensor(float(loss), device=device)
        out = {"loss": loss, "eval_loss": eval_loss}
        if logits is not None:
            out["logits"] = logits
        if labels is not None:
            out["labels"] = labels
        out["mm_info"] = info
        return out


class ContrastiveDataCollator:
    """
    Collator used by Trainer:
    - If given raw samples (list of dicts), it calls multimodal_collate
    - Then selects a random modality to mask (mask_target)
    """

    def __init__(self, mask_prob: float = P_MASK, modalities: Optional[List[str]] = None):
        # NOTE(review): mask_prob is stored but masking itself uses the module
        # constant P_MASK inside HFMultimodalModule.forward — confirm intended.
        self.mask_prob = mask_prob
        self.modalities = modalities if modalities is not None else ["gine", "schnet", "fp", "psmiles"]

    def __call__(self, features):
        """Return {"batch": collated_dict, "mask_target": modality_name}."""
        # Case 1: already a collated multimodal dict.
        if isinstance(features, dict):
            collated = features
            mask_target = random.choice([m for m in self.modalities if m in collated])
            return {"batch": collated, "mask_target": mask_target}
        if isinstance(features, (list, tuple)) and len(features) > 0:
            first = features[0]
            # Case 2: list of raw per-sample dicts -> collate them here.
            if isinstance(first, dict) and "gine" in first:
                collated = multimodal_collate(list(features))
                mask_target = random.choice([m for m in self.modalities if m in collated])
                return {"batch": collated, "mask_target": mask_target}
            # Case 3: list whose first element is already a packed batch.
            if isinstance(first, dict) and "batch" in first:
                collated = first["batch"]
                mask_target = first.get("mask_target", random.choice([m for m in self.modalities if m in collated]))
                return {"batch": collated, "mask_target": mask_target}
        print("ContrastiveDataCollator received unexpected 'features' shape/type.")
        raise ValueError("ContrastiveDataCollator could not collate input. Expected list[dict] with 'gine' key or already-collated dict.")


class VerboseTrainingCallback(TrainerCallback):
    """
    Console-first training callback with early stopping on eval_loss.
    """

    def __init__(self, patience: int = 10):
        self.start_time = time.time()        # wall-clock start of the whole run
        self.epoch_start_time = time.time()  # reset at each epoch begin
        self._last_train_loss = None         # last non-zero train loss from log_history
        self.best_val_loss = float("inf")
        self.best_epoch = 0
        self.epochs_no_improve = 0
        self.patience = patience             # early-stopping patience, in epochs
        self.trainer_ref = None              # wired up externally after Trainer construction

    def save_best_model(self, output_dir_suffix: str = "best"):
        """Persist the current model under OUTPUT_DIR/<suffix> via the trainer."""
        if self.trainer_ref is None:
            return
        try:
            ckpt_dir = os.path.join(OUTPUT_DIR, output_dir_suffix)
            os.makedirs(ckpt_dir, exist_ok=True)
            # Uses the trainer's private _save (CLTrainer overrides it below).
            self.trainer_ref._save(ckpt_dir)
            print(f"Saved best model checkpoint to {ckpt_dir}")
        except Exception as e:
            try:
                self.trainer_ref.save_model(os.path.join(OUTPUT_DIR, output_dir_suffix))
                print(f"Saved best model (fallback) to {os.path.join(OUTPUT_DIR, output_dir_suffix)}")
            except Exception as e2:
                print("Warning: failed to save best model:", e, e2)

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()
        print("=" * 80)
        print(" STARTING MULTIMODAL CONTRASTIVE LEARNING TRAINING")
        print("=" * 80)
        model = kwargs.get("model")
        if model is not None:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            non_trainable_params = total_params - trainable_params
            print(" MODEL PARAMETERS:")
            print(f" Total Parameters: {total_params:,}")
            print(f" Trainable Parameters: {trainable_params:,}")
            print(f" Non-trainable Parameters: {non_trainable_params:,}")
        print(f" Training Progress: 0/{args.num_train_epochs} epochs")
        print("=" * 80)

    def on_epoch_begin(self, args, state, control, **kwargs):
        self.epoch_start_time = time.time()
        current_epoch = state.epoch if state is not None else 0.0
        print(f" Epoch {current_epoch + 1:.1f}/{args.num_train_epochs} Starting...")

    def on_epoch_end(self, args, state, control, **kwargs):
        # Remember the most recent non-zero train loss for on_evaluate reporting.
        train_loss = None
        for log in reversed(state.log_history):
            if isinstance(log, dict) and "loss" in log and float(log.get("loss", 0)) != 0.0:
                train_loss = log["loss"]
                break
        self._last_train_loss = train_loss

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            current_step = state.global_step
            current_epoch = state.epoch
            # NOTE(review): relies on the module-level global `train_loader`
            # being set by main() before training starts — confirm.
            try:
                steps_per_epoch = max(1, len(train_loader) // args.gradient_accumulation_steps)
            except Exception:
                steps_per_epoch = 1
            # Print roughly ten progress lines per epoch.
            if current_step % max(1, steps_per_epoch // 10) == 0:
                progress = current_epoch + (current_step % steps_per_epoch) / steps_per_epoch
                print(f" Step {current_step:4d} | Epoch {progress:.1f} | Train Loss: {logs['loss']:.6f}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        current_epoch = state.epoch if state is not None else 0.0
        epoch_time = time.time() - self.epoch_start_time
        hf_metrics = metrics if metrics is not None else kwargs.get("metrics", None)
        hf_eval_loss = None
        hf_train_loss = self._last_train_loss
        if hf_metrics is not None:
            hf_eval_loss = hf_metrics.get("eval_loss", hf_metrics.get("loss", None))
            if hf_train_loss is None:
                hf_train_loss = hf_metrics.get("train_loss", hf_train_loss)
        # Also run the contrastive evaluator (reads globals set in main()).
        cl_metrics = {}
        try:
            model = kwargs.get("model", None)
            if model is not None:
                cl_model = model.mm if hasattr(model, "mm") else model
                cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
            else:
                cl_metrics = evaluate_multimodal(multimodal_model, val_loader, device, tokenizer, mask_target="fp")
        except Exception as e:
            print("Warning: evaluate_multimodal inside callback failed:", e)
        if hf_eval_loss is None:
            hf_eval_loss = cl_metrics.get("eval_loss", None)
        val_acc = cl_metrics.get("eval_accuracy", "N/A")
        val_f1 = cl_metrics.get("eval_f1_weighted", "N/A")
        print(f" EPOCH {current_epoch + 1:.1f} RESULTS:")
        # The try/except pairs below guard the ":.6f" format against
        # non-numeric values (e.g. None or strings in the metrics dict).
        if hf_train_loss is not None:
            try:
                print(f" Train Loss (HF reported): {hf_train_loss:.6f}")
            except Exception:
                print(f" Train Loss (HF reported): {hf_train_loss}")
        else:
            print(" Train Loss (HF reported): N/A")
        if hf_eval_loss is not None:
            try:
                print(f" Eval Loss (HF reported): {hf_eval_loss:.6f}")
            except Exception:
                print(f" Eval Loss (HF reported): {hf_eval_loss}")
        else:
            print(" Eval Loss (HF reported): N/A")
        if isinstance(val_acc, float):
            print(f" Eval Acc (CL evaluator): {val_acc:.6f}")
        else:
            print(f" Eval Acc (CL evaluator): {val_acc}")
        if isinstance(val_f1, float):
            print(f" Eval F1 Weighted (CL evaluator): {val_f1:.6f}")
        else:
            print(f" Eval F1 Weighted (CL evaluator): {val_f1}")
        # Early stopping on eval loss with a tiny improvement margin (1e-6).
        current_val = hf_eval_loss if hf_eval_loss is not None else float("inf")
        if current_val < self.best_val_loss - 1e-6:
            self.best_val_loss = current_val
            self.best_epoch = current_epoch
            self.epochs_no_improve = 0
            try:
                self.save_best_model("best")
            except Exception as e:
                print("Warning: saving best model failed:", e)
        else:
            self.epochs_no_improve += 1
            if self.epochs_no_improve >= self.patience:
                print(f"Early stopping: no improvement in val_loss for {self.patience} epochs.")
                control.should_training_stop = True
        print(f" Epoch Training Time: {epoch_time:.2f}s")
        print(f" Best Val Loss so far: {self.best_val_loss}")
        print(f" Epochs since improvement: {self.epochs_no_improve}/{self.patience}")
        print("-" * 50)

    def on_train_end(self, args, state, control, **kwargs):
        total_time = time.time() - self.start_time
        print("=" * 80)
        print(" TRAINING COMPLETED")
        print("=" * 80)
        print(f" Total Training Time: {total_time:.2f}s")
        if state is not None:
            try:
                print(f" Total Epochs Completed: {state.epoch + 1:.1f}")
            except Exception:
                pass
        print("=" * 80)


class CLTrainer(Trainer):
    """
    Custom Trainer:
    - evaluate(): merges HF eval with contrastive evaluator
    - _save(): saves a state_dict under pytorch_model.bin
    - _load_best_model(): loads best pytorch_model.bin
    """

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Run HF evaluation, then merge in the contrastive evaluator's metrics."""
        try:
            metrics = super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) or {}
        except Exception as e:
            # HF evaluation failed entirely -> fall back to the CL evaluator.
            print("Warning: super().evaluate() raised an exception. Falling back to CL-only evaluator.")
            import traceback
            traceback.print_exc()
            try:
                cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
                cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
                metrics = {k: float(v) if isinstance(v, (float, int, np.floating, np.integer)) else v for k, v in cl_metrics.items()}
                metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
            except Exception as e2:
                print("Fallback evaluate_multimodal failed as well:", e2)
                traceback.print_exc()
                metrics = {"eval_loss": float("nan"), "epoch": float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else 0.0}
            return metrics
        # HF succeeded: overlay the CL evaluator's metrics on top.
        try:
            cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
            cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
        except Exception as e:
            print("Warning: evaluate_multimodal failed inside CLTrainer.evaluate():", e)
            cl_metrics = {}
        for k, v in cl_metrics.items():
            try:
                metrics[k] = float(v)
            except Exception:
                metrics[k] = v
        if "eval_loss" not in metrics and "eval_loss" in cl_metrics:
            try:
                metrics["eval_loss"] = float(cl_metrics["eval_loss"])
            except Exception:
                metrics["eval_loss"] = cl_metrics["eval_loss"]
        if "epoch" not in metrics:
            metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
        return metrics

    def _save(self, output_dir: str):
        """Save trainer state, a raw state_dict (pytorch_model.bin), optimizer and scheduler."""
        os.makedirs(output_dir, exist_ok=True)
        try:
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
        except Exception:
            pass
        try:
            # Save the inner multimodal model, not the HF wrapper module.
            model_to_save = self.model.mm if hasattr(self.model, "mm") else self.model
            torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        except Exception as e:
            try:
                if hasattr(self.model, "save_pretrained"):
                    self.model.save_pretrained(output_dir)
                else:
                    raise e
            except Exception as e2:
                print("Warning: failed to save model state_dict:", e2)
        # Optimizer/scheduler are best-effort; they may not exist yet.
        try:
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        except Exception:
            pass
        try:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        except Exception:
            pass

    def _load_best_model(self):
        """Load the best checkpoint's pytorch_model.bin into the inner model (non-strict)."""
        best_ckpt = self.state.best_model_checkpoint
        if not best_ckpt:
            return
        candidate = os.path.join(best_ckpt, "pytorch_model.bin")
        if not os.path.exists(candidate):
            candidate = os.path.join(best_ckpt, "model.bin")
        if not os.path.exists(candidate):
            candidate = None
        if candidate is None:
            print(f"CLTrainer._load_best_model(): no compatible pytorch_model.bin found in {best_ckpt}; skipping load.")
            return
        try:
            state_dict = torch.load(candidate, map_location=self.args.device)
            model_to_load = self.model.mm if hasattr(self.model, "mm") else self.model
            model_to_load.load_state_dict(state_dict, strict=False)
            print(f"CLTrainer: loaded best model state_dict from {candidate}")
        except Exception as e:
            print("CLTrainer._load_best_model: failed to load state_dict using torch.load:", e)
            return


# =============================================================================
# Model construction + weight loading
# =============================================================================


def load_state_dict_if_present(model: nn.Module, ckpt_dir: str, filename: str = "pytorch_model.bin") -> None:
    """Load model weights if the checkpoint file exists."""
    path = os.path.join(ckpt_dir, filename)
    if os.path.exists(path):
        try:
            # strict=False tolerates missing/unexpected keys across versions.
            model.load_state_dict(torch.load(path, map_location="cpu"), strict=False)
            print(f"Loaded weights from {path}")
        except Exception as e:
            print(f"Could not load weights from {path}: {e}")


def build_models(device: torch.device) -> Tuple[MultimodalContrastiveModel, PSMILESDebertaEncoder]:
    """Instantiate unimodal encoders, optionally load best checkpoints, and assemble the multimodal model."""
    # GINE
    gine_encoder = GineEncoder(node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS, max_atomic_z=MAX_ATOMIC_Z)
    load_state_dict_if_present(gine_encoder, BEST_GINE_DIR)
    gine_encoder.to(device)
    # SchNet
    schnet_encoder = NodeSchNetWrapper(
        hidden_channels=SCHNET_HIDDEN,
        num_interactions=SCHNET_NUM_INTERACTIONS,
        num_gaussians=SCHNET_NUM_GAUSSIANS,
        cutoff=SCHNET_CUTOFF,
        max_num_neighbors=SCHNET_MAX_NEIGHBORS,
    )
    load_state_dict_if_present(schnet_encoder, BEST_SCHNET_DIR)
    schnet_encoder.to(device)
    # Fingerprint encoder
    fp_encoder = FingerprintEncoder(
        vocab_size=VOCAB_SIZE_FP,
        hidden_dim=256,
        seq_len=FP_LENGTH,
        num_layers=4,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1,
    )
    load_state_dict_if_present(fp_encoder, BEST_FP_DIR)
    fp_encoder.to(device)
    # PSMILES / DeBERTa: prefer a locally saved checkpoint, else fresh init.
    psmiles_encoder = None
    if os.path.isdir(BEST_PSMILES_DIR):
        try:
            psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
            print("Loaded Deberta (PSMILES) from", BEST_PSMILES_DIR)
        except Exception as e:
            print("Failed to load Deberta from saved directory:", e)
    if psmiles_encoder is None:
        psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=None)
    psmiles_encoder.to(device)
    multimodal_model = MultimodalContrastiveModel(gine_encoder, schnet_encoder, fp_encoder, psmiles_encoder, emb_dim=600)
    multimodal_model.to(device)
    return multimodal_model, psmiles_encoder


# =============================================================================
# Main execution
# =============================================================================


def main():
    """End-to-end entry point: data prep, model build, HF training, final eval."""
    # ---- setup ----
    ensure_dir(OUTPUT_DIR)
    ensure_dir(PREPROC_DIR)
    device_local = get_device()
    print("Device:", device_local)
    set_seed(42)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=25,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        eval_strategy="epoch",
        logging_steps=100,
        learning_rate=1e-4,
        weight_decay=0.01,
        eval_accumulation_steps=1000,
        fp16=torch.cuda.is_available(),
        save_strategy="epoch",
        save_steps=500,
        disable_tqdm=False,
        logging_first_step=True,
        report_to=[],
        dataloader_num_workers=0,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
    # ---- data ----
    sample_files = prepare_or_load_data_streaming(
        csv_path=CSV_PATH,
        preproc_dir=PREPROC_DIR,
        target_rows=TARGET_ROWS,
        chunksize=CHUNKSIZE,
    )
    tokenizer_local = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
    # Expose shared state for the callback / evaluator, which read globals.
    global train_loader, val_loader, multimodal_model, device, tokenizer
    tokenizer = tokenizer_local
    device = device_local
    train_loader, val_loader, train_subset, val_subset = build_dataloaders(
        sample_files=sample_files,
        tokenizer=tokenizer_local,
        train_batch_size=training_args.per_device_train_batch_size,
        eval_batch_size=training_args.per_device_eval_batch_size,
        seed=42,
    )
    # ---- models ----
    multimodal_model, _psmiles_encoder = build_models(device_local)
    hf_model = HFMultimodalModule(multimodal_model, tokenizer_local).to(device_local)
    data_collator = ContrastiveDataCollator(mask_prob=P_MASK)
    callback = VerboseTrainingCallback(patience=10)
    trainer = CLTrainer(
        model=hf_model,
        args=training_args,
        train_dataset=train_subset,
        eval_dataset=val_subset,
        data_collator=data_collator,
        callbacks=[callback],
    )
    callback.trainer_ref = trainer
    # Force HF Trainer to use our prebuilt PyTorch DataLoaders
    trainer.get_train_dataloader = lambda dataset=None: train_loader
    trainer.get_eval_dataloader = lambda eval_dataset=None: val_loader
    # Optimizer
    # NOTE(review): this optimizer is created but never handed to the Trainer,
    # so HF builds its own internally — confirm whether this is intentional.
    _optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
    total_params = sum(p.numel() for p in multimodal_model.parameters())
    trainable_params = sum(p.numel() for p in multimodal_model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print("\n MODEL PARAMETERS:")
    print(f" Total Parameters: {total_params:,}")
    print(f" Trainable Parameters: {trainable_params:,}")
    print(f" Non-trainable Parameters: {non_trainable_params:,}")
    # Clear GPU cache
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
    # ---- Train ----
    training_start_time = time.time()
    trainer.train()
    training_end_time = time.time()
    # ---- Save best ----
    best_dir = os.path.join(OUTPUT_DIR, "best")
    os.makedirs(best_dir, exist_ok=True)
    try:
        best_ckpt = trainer.state.best_model_checkpoint
        if best_ckpt:
            multimodal_model.load_state_dict(torch.load(os.path.join(best_ckpt, "pytorch_model.bin"), map_location=device_local), strict=False)
            print(f"Loaded best checkpoint from {best_ckpt} into multimodal_model for final evaluation.")
        torch.save(multimodal_model.state_dict(), os.path.join(best_dir, "pytorch_model.bin"))
        print(f" Saved best multimodal model to {os.path.join(best_dir, 'pytorch_model.bin')}")
    except Exception as e:
        print("Warning: failed to load/save best model from Trainer:", e)
    # ---- Final Evaluation ----
    final_metrics = {}
    try:
        if trainer.state.best_model_checkpoint:
            trainer._load_best_model()
            final_metrics = trainer.evaluate(eval_dataset=val_subset)
        else:
            final_metrics = evaluate_multimodal(multimodal_model, val_loader, device_local, tokenizer_local, mask_target="fp")
    except Exception as e:
        print("Warning: final evaluation via trainer.evaluate failed, falling back to direct evaluate_multimodal:", e)
        final_metrics = evaluate_multimodal(multimodal_model, val_loader, device_local, tokenizer_local, mask_target="fp")
    print("\n" + "=" * 80)
    print(" FINAL TRAINING RESULTS")
    print("=" * 80)
    print(f"Total Training Time: {training_end_time - training_start_time:.2f}s")
    best_ckpt = trainer.state.best_model_checkpoint if hasattr(trainer.state, "best_model_checkpoint") else None
    print(f"Best Checkpoint: {best_ckpt if best_ckpt else '(none saved)'}")
    hf_eval_loss = final_metrics.get("eval_loss", float("nan"))
    hf_eval_acc = final_metrics.get("eval_accuracy", 0.0)
    hf_eval_f1 = final_metrics.get("eval_f1_weighted", 0.0)
    print(f"Val Loss (HF reported / trainer.evaluate): {hf_eval_loss:.4f}")
    print(f"Val Acc (CL evaluator): {hf_eval_acc:.4f}")
    print(f"Val F1 Weighted (CL evaluator): {hf_eval_f1:.4f}")
    # Parameter counts were computed before training; they reflect the model
    # as built, not any post-training changes.
    print(f"Total Trainable Params: {trainable_params:,}")
    print(f"Total Non-trainable Params: {non_trainable_params:,}")
    print("=" * 80)


if __name__ == "__main__":
    main()