Spaces:
Running
Running
| """ | |
| GINE.py | |
| GINE-based masked pretraining on polymer 2D graphs. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import json | |
| import time | |
| import sys | |
| import csv | |
| import argparse | |
| from typing import Any, Dict, List, Optional, Tuple | |
| # Increase max CSV field size limit | |
| csv.field_size_limit(sys.maxsize) | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import TrainingArguments, Trainer | |
| from transformers.trainer_callback import TrainerCallback | |
| from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error | |
| from torch_geometric.nn import GINEConv | |
| # --------------------------- | |
| # Configuration / Constants | |
| # --------------------------- | |
| P_MASK = 0.15 | |
| MAX_ATOMIC_Z = 85 | |
| MASK_ATOM_ID = MAX_ATOMIC_Z + 1 | |
| USE_LEARNED_WEIGHTING = True | |
| NODE_EMB_DIM = 300 | |
| EDGE_EMB_DIM = 300 | |
| NUM_GNN_LAYERS = 5 | |
| K_ANCHORS = 6 | |
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description="GINE masked pretraining (graphs).") | |
| parser.add_argument( | |
| "--csv_path", | |
| type=str, | |
| default="/path/to/polymer_structures_unified_processed.csv", | |
| help="Processed CSV containing a JSON 'graph' column.", | |
| ) | |
| parser.add_argument("--target_rows", type=int, default=5_000_000, help="Max rows to parse.") | |
| parser.add_argument("--chunksize", type=int, default=50_000, help="CSV chunksize.") | |
| parser.add_argument("--output_dir", type=str, default="/path/to/gin_output_5M", help="Training output directory.") | |
| parser.add_argument("--num_workers", type=int, default=4, help="PyTorch DataLoader num workers.") | |
| return parser.parse_args() | |
| # --------------------------- | |
| # Helper functions | |
| # --------------------------- | |
| def safe_get(d: dict, key: str, default=None): | |
| return d[key] if (isinstance(d, dict) and key in d) else default | |
| def build_adj_list(edge_index: torch.Tensor, num_nodes: int) -> List[List[int]]: | |
| """Adjacency list for BFS shortest paths.""" | |
| adj = [[] for _ in range(num_nodes)] | |
| if edge_index is None or edge_index.numel() == 0: | |
| return adj | |
| src = edge_index[0].tolist() | |
| dst = edge_index[1].tolist() | |
| for u, v in zip(src, dst): | |
| if 0 <= u < num_nodes and 0 <= v < num_nodes: | |
| adj[u].append(v) | |
| return adj | |
| def shortest_path_lengths_hops(edge_index: torch.Tensor, num_nodes: int) -> np.ndarray: | |
| """ | |
| All-pairs shortest path lengths in hops using BFS per node. | |
| Unreachable pairs get distance INF=num_nodes+1. | |
| """ | |
| adj = build_adj_list(edge_index, num_nodes) | |
| INF = num_nodes + 1 | |
| dist_mat = np.full((num_nodes, num_nodes), INF, dtype=np.int32) | |
| for s in range(num_nodes): | |
| q = [s] | |
| dist_mat[s, s] = 0 | |
| head = 0 | |
| while head < len(q): | |
| u = q[head] | |
| head += 1 | |
| for v in adj[u]: | |
| if dist_mat[s, v] == INF: | |
| dist_mat[s, v] = dist_mat[s, u] + 1 | |
| q.append(v) | |
| return dist_mat | |
| def match_edge_attr_to_index(edge_index: torch.Tensor, edge_attr: torch.Tensor, target_dim: int = 3) -> torch.Tensor: | |
| """ | |
| Ensure edge_attr has shape [E_index, target_dim], handling common mismatches. | |
| """ | |
| E_idx = edge_index.size(1) if (edge_index is not None and edge_index.numel() > 0) else 0 | |
| if E_idx == 0: | |
| return torch.zeros((0, target_dim), dtype=torch.float) | |
| if edge_attr is None or edge_attr.numel() == 0: | |
| return torch.zeros((E_idx, target_dim), dtype=torch.float) | |
| E_attr = edge_attr.size(0) | |
| if E_attr == E_idx: | |
| if edge_attr.size(1) != target_dim: | |
| D = edge_attr.size(1) | |
| if D < target_dim: | |
| pad = torch.zeros((E_attr, target_dim - D), dtype=torch.float, device=edge_attr.device) | |
| return torch.cat([edge_attr, pad], dim=1) | |
| return edge_attr[:, :target_dim] | |
| return edge_attr | |
| if E_attr * 2 == E_idx: | |
| try: | |
| return torch.cat([edge_attr, edge_attr], dim=0) | |
| except Exception: | |
| pass | |
| reps = (E_idx + E_attr - 1) // E_attr | |
| edge_rep = edge_attr.repeat(reps, 1)[:E_idx] | |
| if edge_rep.size(1) != target_dim: | |
| D = edge_rep.size(1) | |
| if D < target_dim: | |
| pad = torch.zeros((E_idx, target_dim - D), dtype=torch.float, device=edge_rep.device) | |
| edge_rep = torch.cat([edge_rep, pad], dim=1) | |
| else: | |
| edge_rep = edge_rep[:, :target_dim] | |
| return edge_rep | |
| def parse_graphs_from_csv(csv_path: str, target_rows: int, chunksize: int): | |
| """ | |
| Stream CSV and parse the JSON 'graph' field into graph tensors needed by the model. | |
| Returns lists of per-graph tensors. | |
| """ | |
| node_atomic_lists = [] | |
| node_chirality_lists = [] | |
| node_charge_lists = [] | |
| edge_index_lists = [] | |
| edge_attr_lists = [] | |
| num_nodes_list = [] | |
| rows_read = 0 | |
| for chunk in pd.read_csv(csv_path, engine="python", chunksize=chunksize): | |
| for _, row in chunk.iterrows(): | |
| graph_field = None | |
| if "graph" in row and not pd.isna(row["graph"]): | |
| try: | |
| graph_field = json.loads(row["graph"]) if isinstance(row["graph"], str) else row["graph"] | |
| except Exception: | |
| graph_field = None | |
| else: | |
| continue | |
| if graph_field is None: | |
| continue | |
| node_features = safe_get(graph_field, "node_features", None) | |
| if not node_features: | |
| continue | |
| atomic_nums = [] | |
| chirality_vals = [] | |
| formal_charges = [] | |
| for nf in node_features: | |
| an = safe_get(nf, "atomic_num", safe_get(nf, "atomic_number", 0)) | |
| ch = safe_get(nf, "chirality", 0) | |
| fc = safe_get(nf, "formal_charge", 0) | |
| atomic_nums.append(int(an)) | |
| chirality_vals.append(float(ch)) | |
| formal_charges.append(float(fc)) | |
| n_nodes = len(atomic_nums) | |
| edge_indices_raw = safe_get(graph_field, "edge_indices", None) | |
| edge_features_raw = safe_get(graph_field, "edge_features", None) | |
| if edge_indices_raw is None: | |
| adj_mat = safe_get(graph_field, "adjacency_matrix", None) | |
| if adj_mat: | |
| srcs, dsts = [], [] | |
| for i, row_adj in enumerate(adj_mat): | |
| for j, val in enumerate(row_adj): | |
| if val: | |
| srcs.append(i) | |
| dsts.append(j) | |
| edge_index = torch.tensor([srcs, dsts], dtype=torch.long) | |
| E = edge_index.size(1) | |
| edge_attr = torch.zeros((E, 3), dtype=torch.float) | |
| else: | |
| continue | |
| else: | |
| srcs, dsts = [], [] | |
| if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list): | |
| if all(len(pair) == 2 and isinstance(pair[0], int) for pair in edge_indices_raw): | |
| srcs = [int(p[0]) for p in edge_indices_raw] | |
| dsts = [int(p[1]) for p in edge_indices_raw] | |
| elif isinstance(edge_indices_raw[0][0], int): | |
| try: | |
| srcs = [int(x) for x in edge_indices_raw[0]] | |
| dsts = [int(x) for x in edge_indices_raw[1]] | |
| except Exception: | |
| srcs, dsts = [], [] | |
| if len(srcs) == 0: | |
| continue | |
| edge_index = torch.tensor([srcs, dsts], dtype=torch.long) | |
| if edge_features_raw and isinstance(edge_features_raw, list): | |
| bond_types, stereos, is_conjs = [], [], [] | |
| for ef in edge_features_raw: | |
| bt = safe_get(ef, "bond_type", 0) | |
| st = safe_get(ef, "stereo", 0) | |
| ic = safe_get(ef, "is_conjugated", False) | |
| bond_types.append(float(bt)) | |
| stereos.append(float(st)) | |
| is_conjs.append(float(1.0 if ic else 0.0)) | |
| edge_attr = torch.tensor(np.stack([bond_types, stereos, is_conjs], axis=1), dtype=torch.float) | |
| else: | |
| E = edge_index.size(1) | |
| edge_attr = torch.zeros((E, 3), dtype=torch.float) | |
| edge_attr = match_edge_attr_to_index(edge_index, edge_attr, target_dim=3) | |
| node_atomic_lists.append(torch.tensor(atomic_nums, dtype=torch.long)) | |
| node_chirality_lists.append(torch.tensor(chirality_vals, dtype=torch.float)) | |
| node_charge_lists.append(torch.tensor(formal_charges, dtype=torch.float)) | |
| edge_index_lists.append(edge_index) | |
| edge_attr_lists.append(edge_attr) | |
| num_nodes_list.append(n_nodes) | |
| rows_read += 1 | |
| if rows_read >= target_rows: | |
| break | |
| if rows_read >= target_rows: | |
| break | |
| if len(node_atomic_lists) == 0: | |
| raise RuntimeError("No graphs were parsed from the CSV 'graph' column. Check input file and format.") | |
| print(f"Parsed {len(node_atomic_lists)} graphs (using 'graph' column). Using manual max atomic Z = {MAX_ATOMIC_Z}") | |
| return ( | |
| node_atomic_lists, | |
| node_chirality_lists, | |
| node_charge_lists, | |
| edge_index_lists, | |
| edge_attr_lists, | |
| num_nodes_list, | |
| ) | |
| def compute_class_weights(train_atomic: List[torch.Tensor]) -> torch.Tensor: | |
| """Compute inverse-frequency class weights for atomic number prediction.""" | |
| num_classes = MASK_ATOM_ID + 1 | |
| counts = np.ones((num_classes,), dtype=np.float64) | |
| for z in train_atomic: | |
| vals = z.cpu().numpy().astype(int) | |
| for v in vals: | |
| if 0 <= v < num_classes: | |
| counts[v] += 1.0 | |
| freq = counts / counts.sum() | |
| inv_freq = 1.0 / (freq + 1e-12) | |
| class_weights = inv_freq / inv_freq.mean() | |
| class_weights = torch.tensor(class_weights, dtype=torch.float) | |
| class_weights[MASK_ATOM_ID] = 1.0 | |
| return class_weights | |
| # ============================================================================= | |
| # Encoder wrapper used by MaskedGINE | |
| # ============================================================================= | |
| class GineBlock(nn.Module): | |
| """One GINEConv block (MLP + BN + ReLU).""" | |
| def __init__(self, node_dim: int): | |
| super().__init__() | |
| self.mlp = nn.Sequential(nn.Linear(node_dim, node_dim), nn.ReLU(), nn.Linear(node_dim, node_dim)) | |
| self.conv = GINEConv(self.mlp) | |
| self.bn = nn.BatchNorm1d(node_dim) | |
| self.act = nn.ReLU() | |
| def forward(self, x, edge_index, edge_attr): | |
| x = self.conv(x, edge_index, edge_attr) | |
| x = self.bn(x) | |
| x = self.act(x) | |
| return x | |
| class GineEncoder(nn.Module): | |
| """ | |
| Graph encoder: | |
| - Produces node embeddings via GINE | |
| - Provides pooled graph embedding via mean pooling + pool_proj | |
| - Provides node_logits(...) for reconstruction (atomic prediction head) | |
| """ | |
| def __init__( | |
| self, | |
| node_emb_dim: int = NODE_EMB_DIM, | |
| edge_emb_dim: int = EDGE_EMB_DIM, | |
| num_layers: int = NUM_GNN_LAYERS, | |
| max_atomic_z: int = MAX_ATOMIC_Z, | |
| emb_dim: int = 600, | |
| class_weights: Optional[torch.Tensor] = None, | |
| ): | |
| super().__init__() | |
| self.node_emb_dim = node_emb_dim | |
| self.edge_emb_dim = edge_emb_dim | |
| self.max_atomic_z = max_atomic_z | |
| num_embeddings = MASK_ATOM_ID + 1 | |
| self.atom_emb = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=node_emb_dim, padding_idx=None) | |
| self.node_attr_proj = nn.Sequential(nn.Linear(2, node_emb_dim), nn.ReLU(), nn.Linear(node_emb_dim, node_emb_dim)) | |
| self.edge_encoder = nn.Sequential(nn.Linear(3, edge_emb_dim), nn.ReLU(), nn.Linear(edge_emb_dim, edge_emb_dim)) | |
| self._edge_to_node_proj = nn.Linear(edge_emb_dim, node_emb_dim) if edge_emb_dim != node_emb_dim else None | |
| self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)]) | |
| # node head for masked-atom reconstruction | |
| self.atom_head = nn.Linear(node_emb_dim, MASK_ATOM_ID + 1) | |
| # pooled embedding projection | |
| self.pool_proj = nn.Linear(node_emb_dim, emb_dim) | |
| if class_weights is not None: | |
| self.register_buffer("class_weights", class_weights) | |
| else: | |
| self.class_weights = None | |
| def encode_nodes(self, z, chirality, formal_charge, edge_index, edge_attr): | |
| if z.numel() == 0: | |
| return torch.zeros((0, self.node_emb_dim), device=z.device) | |
| atom_embedding = self.atom_emb(z) | |
| node_attr = torch.stack([chirality, formal_charge], dim=1) | |
| node_attr_emb = self.node_attr_proj(node_attr.to(atom_embedding.device)) | |
| x = atom_embedding + node_attr_emb | |
| if edge_attr is None or edge_attr.numel() == 0: | |
| edge_emb = torch.zeros((0, self.edge_emb_dim), dtype=torch.float, device=x.device) | |
| else: | |
| edge_emb = self.edge_encoder(edge_attr.to(x.device)) | |
| edge_for_conv = self._edge_to_node_proj(edge_emb) if self._edge_to_node_proj is not None else edge_emb | |
| h = x | |
| for layer in self.gnn_layers: | |
| h = layer(h, edge_index.to(h.device), edge_for_conv) | |
| return h | |
| def node_logits(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None): | |
| h = self.encode_nodes(z, chirality, formal_charge, edge_index, edge_attr) | |
| return self.atom_head(h) | |
| def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None): | |
| """ | |
| Returns pooled graph embedding (B, emb_dim). | |
| Pool = mean over nodes per graph (batch vector). | |
| """ | |
| if batch is None: | |
| batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device) | |
| h = self.encode_nodes(z, chirality, formal_charge, edge_index, edge_attr) | |
| if h.size(0) == 0: | |
| # no nodes: return empty batch | |
| B = int(batch.max().item() + 1) if batch.numel() > 0 else 0 | |
| return torch.zeros((B, self.pool_proj.out_features), device=z.device) | |
| B = int(batch.max().item() + 1) if batch.numel() > 0 else 1 | |
| pooled = torch.zeros((B, h.size(1)), device=h.device) | |
| counts = torch.zeros((B,), device=h.device).clamp(min=0.0) | |
| pooled.index_add_(0, batch, h) | |
| ones = torch.ones((h.size(0),), device=h.device) | |
| counts.index_add_(0, batch, ones) | |
| pooled = pooled / counts.clamp(min=1.0).unsqueeze(-1) | |
| return self.pool_proj(pooled) | |
| # ============================================================================= | |
| # Training dataset + collate | |
| # ============================================================================= | |
| class PolymerDataset(Dataset): | |
| """Holds per-graph tensors; collation builds a single batched graph with masking targets.""" | |
| def __init__(self, atomic_list, chirality_list, charge_list, edge_index_list, edge_attr_list, num_nodes_list): | |
| self.atomic_list = atomic_list | |
| self.chirality_list = chirality_list | |
| self.charge_list = charge_list | |
| self.edge_index_list = edge_index_list | |
| self.edge_attr_list = edge_attr_list | |
| self.num_nodes_list = num_nodes_list | |
| def __len__(self): | |
| return len(self.atomic_list) | |
| def __getitem__(self, idx): | |
| return { | |
| "z": self.atomic_list[idx], | |
| "chirality": self.chirality_list[idx], | |
| "formal_charge": self.charge_list[idx], | |
| "edge_index": self.edge_index_list[idx], | |
| "edge_attr": self.edge_attr_list[idx], | |
| "num_nodes": int(self.num_nodes_list[idx]), | |
| } | |
| def collate_batch(batch): | |
| """ | |
| Build a single batched graph (node-concatenation with edge index offsets) and create: | |
| - masked node labels (labels_z) | |
| - hop-distance anchor targets (labels_dists) for masked nodes | |
| """ | |
| all_z, all_ch, all_fc = [], [], [] | |
| all_labels_z, all_labels_dists, all_labels_dists_mask = [], [], [] | |
| batch_idx = [] | |
| edge_index_list_batched = [] | |
| edge_attr_list_batched = [] | |
| node_offset = 0 | |
| for i, g in enumerate(batch): | |
| z = g["z"] | |
| n = z.size(0) | |
| if n == 0: | |
| continue | |
| chir = g["chirality"] | |
| fc = g["formal_charge"] | |
| edge_index = g["edge_index"] | |
| edge_attr = g["edge_attr"] | |
| is_selected = torch.rand(n) < P_MASK | |
| if is_selected.all(): | |
| is_selected[torch.randint(0, n, (1,))] = False | |
| labels_z = torch.full((n,), -100, dtype=torch.long) | |
| labels_dists = torch.zeros((n, K_ANCHORS), dtype=torch.float) | |
| labels_dists_mask = torch.zeros((n, K_ANCHORS), dtype=torch.bool) | |
| labels_z[is_selected] = z[is_selected] | |
| # BERT-style corruption on atomic numbers | |
| z_masked = z.clone() | |
| if is_selected.any(): | |
| sel_idx = torch.nonzero(is_selected).squeeze(-1) | |
| rand_atomic = torch.randint(1, MAX_ATOMIC_Z + 1, (sel_idx.size(0),), dtype=torch.long) | |
| probs = torch.rand(sel_idx.size(0)) | |
| mask_choice = probs < 0.8 | |
| rand_choice = (probs >= 0.8) & (probs < 0.9) | |
| if mask_choice.any(): | |
| z_masked[sel_idx[mask_choice]] = MASK_ATOM_ID | |
| if rand_choice.any(): | |
| z_masked[sel_idx[rand_choice]] = rand_atomic[rand_choice] | |
| # Hop-distance targets for masked atoms | |
| visible_idx = torch.nonzero(~is_selected).squeeze(-1) | |
| if visible_idx.numel() == 0: | |
| visible_idx = torch.arange(n, dtype=torch.long) | |
| dist_mat = shortest_path_lengths_hops(edge_index.clone(), n) | |
| for a in torch.nonzero(is_selected).squeeze(-1).tolist(): | |
| vis = visible_idx.numpy() | |
| if vis.size == 0: | |
| continue | |
| dists = dist_mat[a, vis].astype(np.float32) | |
| valid_mask = dists <= n | |
| if not valid_mask.any(): | |
| continue | |
| dists_valid = dists[valid_mask] | |
| k = min(K_ANCHORS, dists_valid.size) | |
| idx_sorted = np.argsort(dists_valid)[:k] | |
| labels_dists[a, :k] = torch.tensor(dists_valid[idx_sorted], dtype=torch.float) | |
| labels_dists_mask[a, :k] = True | |
| all_z.append(z_masked) | |
| all_ch.append(chir) | |
| all_fc.append(fc) | |
| all_labels_z.append(labels_z) | |
| all_labels_dists.append(labels_dists) | |
| all_labels_dists_mask.append(labels_dists_mask) | |
| batch_idx.append(torch.full((n,), i, dtype=torch.long)) | |
| if edge_index is not None and edge_index.numel() > 0: | |
| ei_offset = edge_index + node_offset | |
| edge_index_list_batched.append(ei_offset) | |
| edge_attr_matched = match_edge_attr_to_index(edge_index, edge_attr, target_dim=3) | |
| edge_attr_list_batched.append(edge_attr_matched) | |
| node_offset += n | |
| if len(all_z) == 0: | |
| return { | |
| "z": torch.tensor([], dtype=torch.long), | |
| "chirality": torch.tensor([], dtype=torch.float), | |
| "formal_charge": torch.tensor([], dtype=torch.float), | |
| "edge_index": torch.tensor([[], []], dtype=torch.long), | |
| "edge_attr": torch.tensor([], dtype=torch.float).reshape(0, 3), | |
| "batch": torch.tensor([], dtype=torch.long), | |
| "labels_z": torch.tensor([], dtype=torch.long), | |
| "labels_dists": torch.tensor([], dtype=torch.float).reshape(0, K_ANCHORS), | |
| "labels_dists_mask": torch.tensor([], dtype=torch.bool).reshape(0, K_ANCHORS), | |
| } | |
| z_batch = torch.cat(all_z, dim=0) | |
| chir_batch = torch.cat(all_ch, dim=0) | |
| fc_batch = torch.cat(all_fc, dim=0) | |
| labels_z_batch = torch.cat(all_labels_z, dim=0) | |
| labels_dists_batch = torch.cat(all_labels_dists, dim=0) | |
| labels_dists_mask_batch = torch.cat(all_labels_dists_mask, dim=0) | |
| batch_batch = torch.cat(batch_idx, dim=0) | |
| if len(edge_index_list_batched) > 0: | |
| edge_index_batched = torch.cat(edge_index_list_batched, dim=1) | |
| edge_attr_batched = torch.cat(edge_attr_list_batched, dim=0) | |
| else: | |
| edge_index_batched = torch.tensor([[], []], dtype=torch.long) | |
| edge_attr_batched = torch.tensor([], dtype=torch.float).reshape(0, 3) | |
| return { | |
| "z": z_batch, | |
| "chirality": chir_batch, | |
| "formal_charge": fc_batch, | |
| "edge_index": edge_index_batched, | |
| "edge_attr": edge_attr_batched, | |
| "batch": batch_batch, | |
| "labels_z": labels_z_batch, | |
| "labels_dists": labels_dists_batch, | |
| "labels_dists_mask": labels_dists_mask_batch, | |
| } | |
| # ============================================================================= | |
| # Masked pretraining model | |
| # ============================================================================= | |
| class MaskedGINE(nn.Module): | |
| """ | |
| Masked GNN objective: | |
| - predict masked atomic numbers (classification head) | |
| - predict hop-distance anchors for masked nodes (regression head) | |
| - optionally learned uncertainty weighting across the two losses | |
| """ | |
| def __init__( | |
| self, | |
| node_emb_dim=NODE_EMB_DIM, | |
| edge_emb_dim=EDGE_EMB_DIM, | |
| num_layers=NUM_GNN_LAYERS, | |
| max_atomic_z=MAX_ATOMIC_Z, | |
| class_weights=None, | |
| ): | |
| super().__init__() | |
| # Use GineEncoder internally | |
| self.encoder = GineEncoder( | |
| node_emb_dim=node_emb_dim, | |
| edge_emb_dim=edge_emb_dim, | |
| num_layers=num_layers, | |
| max_atomic_z=max_atomic_z, | |
| emb_dim=600, | |
| class_weights=class_weights, | |
| ) | |
| # reuse same heads conceptually: | |
| # encoder has atom_head already; we add hop-distance head here | |
| self.coord_head = nn.Linear(node_emb_dim, K_ANCHORS) | |
| if USE_LEARNED_WEIGHTING: | |
| self.log_var_z = nn.Parameter(torch.zeros(1)) | |
| self.log_var_pos = nn.Parameter(torch.zeros(1)) | |
| else: | |
| self.log_var_z = None | |
| self.log_var_pos = None | |
| # class_weights | |
| self.class_weights = getattr(self.encoder, "class_weights", None) | |
| def forward( | |
| self, | |
| z, | |
| chirality, | |
| formal_charge, | |
| edge_index, | |
| edge_attr, | |
| batch=None, | |
| labels_z=None, | |
| labels_dists=None, | |
| labels_dists_mask=None, | |
| ): | |
| if batch is None: | |
| batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device) | |
| # node embeddings | |
| h = self.encoder.encode_nodes(z, chirality, formal_charge, edge_index, edge_attr) | |
| logits = self.encoder.atom_head(h) | |
| dists_pred = self.coord_head(h) | |
| if labels_z is not None and labels_dists is not None and labels_dists_mask is not None: | |
| mask = labels_z != -100 | |
| if mask.sum() == 0: | |
| return torch.tensor(0.0, device=z.device) | |
| logits_masked = logits[mask] | |
| dists_pred_masked = dists_pred[mask] | |
| labels_z_masked = labels_z[mask] | |
| labels_dists_masked = labels_dists[mask] | |
| labels_dists_mask_mask = labels_dists_mask[mask] | |
| if self.class_weights is not None: | |
| loss_z = F.cross_entropy( | |
| logits_masked, | |
| labels_z_masked.to(logits_masked.device), | |
| weight=self.class_weights.to(logits_masked.device), | |
| ) | |
| else: | |
| loss_z = F.cross_entropy(logits_masked, labels_z_masked.to(logits_masked.device)) | |
| if labels_dists_mask_mask.any(): | |
| preds = dists_pred_masked[labels_dists_mask_mask] | |
| trues = labels_dists_masked[labels_dists_mask_mask].to(preds.device) | |
| loss_pos = F.mse_loss(preds, trues, reduction="mean") | |
| else: | |
| loss_pos = torch.tensor(0.0, device=z.device) | |
| if USE_LEARNED_WEIGHTING: | |
| lz = torch.exp(-self.log_var_z) * loss_z + self.log_var_z | |
| lp = torch.exp(-self.log_var_pos) * loss_pos + self.log_var_pos | |
| return 0.5 * (lz + lp) | |
| return loss_z + loss_pos | |
| return logits, dists_pred | |
| class ValLossCallback(TrainerCallback): | |
| """Evaluation callback: prints metrics, saves best model, and early-stops on val loss.""" | |
| def __init__(self, best_model_dir: str, val_loader: DataLoader, patience: int = 10, trainer_ref=None): | |
| self.best_val_loss = float("inf") | |
| self.epochs_no_improve = 0 | |
| self.patience = patience | |
| self.best_epoch = None | |
| self.trainer_ref = trainer_ref | |
| self.best_model_dir = best_model_dir | |
| self.val_loader = val_loader | |
| def on_epoch_end(self, args, state, control, **kwargs): | |
| epoch_num = int(state.epoch) | |
| train_loss = next((x["loss"] for x in reversed(state.log_history) if "loss" in x), None) | |
| print(f"\n=== Epoch {epoch_num}/{args.num_train_epochs} ===") | |
| if train_loss is not None: | |
| print(f"Train Loss: {train_loss:.4f}") | |
| def on_evaluate(self, args, state, control, metrics=None, **kwargs): | |
| epoch_num = int(state.epoch) + 1 | |
| if self.trainer_ref is None: | |
| print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}") | |
| return | |
| metric_val_loss = metrics.get("eval_loss") if metrics is not None else None | |
| model_eval = self.trainer_ref.model | |
| model_eval.eval() | |
| device_local = next(model_eval.parameters()).device | |
| preds_z_all, true_z_all = [], [] | |
| pred_dists_all, true_dists_all = [], [] | |
| total_loss, n_batches = 0.0, 0 | |
| logits_masked_list, labels_masked_list = [], [] | |
| with torch.no_grad(): | |
| for batch in self.val_loader: | |
| z = batch["z"].to(device_local) | |
| chir = batch["chirality"].to(device_local) | |
| fc = batch["formal_charge"].to(device_local) | |
| edge_index = batch["edge_index"].to(device_local) | |
| edge_attr = batch["edge_attr"].to(device_local) | |
| batch_idx = batch["batch"].to(device_local) | |
| labels_z = batch["labels_z"].to(device_local) | |
| labels_dists = batch["labels_dists"].to(device_local) | |
| labels_dists_mask = batch["labels_dists_mask"].to(device_local) | |
| try: | |
| loss = model_eval(z, chir, fc, edge_index, edge_attr, batch_idx, labels_z, labels_dists, labels_dists_mask) | |
| except Exception: | |
| loss = None | |
| if isinstance(loss, torch.Tensor): | |
| total_loss += loss.item() | |
| n_batches += 1 | |
| logits, dists_pred = model_eval(z, chir, fc, edge_index, edge_attr, batch_idx) | |
| mask = labels_z != -100 | |
| if mask.sum().item() == 0: | |
| continue | |
| logits_masked_list.append(logits[mask]) | |
| labels_masked_list.append(labels_z[mask]) | |
| pred_z = torch.argmax(logits[mask], dim=-1) | |
| true_z = labels_z[mask] | |
| pred_d = dists_pred[mask][labels_dists_mask[mask]] | |
| true_d = labels_dists[mask][labels_dists_mask[mask]] | |
| if pred_d.numel() > 0: | |
| pred_dists_all.extend(pred_d.cpu().tolist()) | |
| true_dists_all.extend(true_d.cpu().tolist()) | |
| preds_z_all.extend(pred_z.cpu().tolist()) | |
| true_z_all.extend(true_z.cpu().tolist()) | |
| avg_val_loss = metric_val_loss if metric_val_loss is not None else ((total_loss / n_batches) if n_batches > 0 else float("nan")) | |
| accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0 | |
| f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0 | |
| rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0 | |
| mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0 | |
| if len(logits_masked_list) > 0: | |
| all_logits_masked = torch.cat(logits_masked_list, dim=0) | |
| all_labels_masked = torch.cat(labels_masked_list, dim=0) | |
| cw = getattr(model_eval, "class_weights", None) | |
| if cw is not None: | |
| try: | |
| loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked, weight=cw.to(device_local)) | |
| except Exception: | |
| loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked) | |
| else: | |
| loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked) | |
| try: | |
| perplexity = float(torch.exp(loss_z_all).cpu().item()) | |
| except Exception: | |
| perplexity = float(np.exp(float(loss_z_all.cpu().item()))) | |
| else: | |
| perplexity = float("nan") | |
| print(f"\n--- Evaluation after Epoch {epoch_num} ---") | |
| print(f"Validation Loss: {avg_val_loss:.4f}") | |
| print(f"Validation Accuracy: {accuracy:.4f}") | |
| print(f"Validation F1 (weighted): {f1:.4f}") | |
| print(f"Validation RMSE (distances): {rmse:.4f}") | |
| print(f"Validation MAE (distances): {mae:.4f}") | |
| print(f"Validation Perplexity (classification head): {perplexity:.4f}") | |
| if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6: | |
| self.best_val_loss = avg_val_loss | |
| self.best_epoch = int(state.epoch) | |
| self.epochs_no_improve = 0 | |
| os.makedirs(self.best_model_dir, exist_ok=True) | |
| try: | |
| torch.save(self.trainer_ref.model.state_dict(), os.path.join(self.best_model_dir, "pytorch_model.bin")) | |
| print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(self.best_model_dir, 'pytorch_model.bin')}") | |
| except Exception as e: | |
| print(f"Failed to save best model at epoch {epoch_num}: {e}") | |
| else: | |
| self.epochs_no_improve += 1 | |
| if self.epochs_no_improve >= self.patience: | |
| print(f"Early stopping after {self.patience} epochs with no improvement.") | |
| control.should_training_stop = True | |
| def build_datasets_and_loaders(parsed, batch_train: int = 16, batch_val: int = 8, num_workers: int = 4): | |
| """Split indices into train/val and construct Dataset/DataLoader.""" | |
| (node_atomic_lists, node_chirality_lists, node_charge_lists, edge_index_lists, edge_attr_lists, num_nodes_list) = parsed | |
| indices = list(range(len(node_atomic_lists))) | |
| train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42) | |
| def subset(l, idxs): | |
| return [l[i] for i in idxs] | |
| train_atomic = subset(node_atomic_lists, train_idx) | |
| train_chirality = subset(node_chirality_lists, train_idx) | |
| train_charge = subset(node_charge_lists, train_idx) | |
| train_edge_index = subset(edge_index_lists, train_idx) | |
| train_edge_attr = subset(edge_attr_lists, train_idx) | |
| train_num_nodes = subset(num_nodes_list, train_idx) | |
| val_atomic = subset(node_atomic_lists, val_idx) | |
| val_chirality = subset(node_chirality_lists, val_idx) | |
| val_charge = subset(node_charge_lists, val_idx) | |
| val_edge_index = subset(edge_index_lists, val_idx) | |
| val_edge_attr = subset(edge_attr_lists, val_idx) | |
| val_num_nodes = subset(num_nodes_list, val_idx) | |
| train_dataset = PolymerDataset(train_atomic, train_chirality, train_charge, train_edge_index, train_edge_attr, train_num_nodes) | |
| val_dataset = PolymerDataset(val_atomic, val_chirality, val_charge, val_edge_index, val_edge_attr, val_num_nodes) | |
| train_loader = DataLoader(train_dataset, batch_size=batch_train, shuffle=True, collate_fn=collate_batch, num_workers=num_workers) | |
| val_loader = DataLoader(val_dataset, batch_size=batch_val, shuffle=False, collate_fn=collate_batch, num_workers=num_workers) | |
| return train_dataset, val_dataset, train_loader, val_loader, train_atomic | |
| def train_and_evaluate(args: argparse.Namespace) -> None: | |
| """Main run: parse data, build model, train, reload best, final eval printout.""" | |
| output_dir = args.output_dir | |
| best_model_dir = os.path.join(output_dir, "best") | |
| os.makedirs(output_dir, exist_ok=True) | |
| parsed = parse_graphs_from_csv(args.csv_path, args.target_rows, args.chunksize) | |
| train_dataset, val_dataset, train_loader, val_loader, train_atomic = build_datasets_and_loaders( | |
| parsed, batch_train=16, batch_val=8, num_workers=args.num_workers | |
| ) | |
| class_weights = compute_class_weights(train_atomic) | |
| model = MaskedGINE( | |
| node_emb_dim=NODE_EMB_DIM, | |
| edge_emb_dim=EDGE_EMB_DIM, | |
| num_layers=NUM_GNN_LAYERS, | |
| max_atomic_z=MAX_ATOMIC_Z, | |
| class_weights=class_weights, | |
| ) | |
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| training_args = TrainingArguments( | |
| output_dir=output_dir, | |
| overwrite_output_dir=True, | |
| num_train_epochs=25, | |
| per_device_train_batch_size=16, | |
| per_device_eval_batch_size=8, | |
| gradient_accumulation_steps=4, | |
| eval_strategy="epoch", | |
| logging_steps=500, | |
| learning_rate=1e-4, | |
| weight_decay=0.01, | |
| fp16=torch.cuda.is_available(), | |
| save_strategy="no", | |
| disable_tqdm=False, | |
| logging_first_step=True, | |
| report_to=[], | |
| dataloader_num_workers=args.num_workers, | |
| ) | |
| callback = ValLossCallback(best_model_dir=best_model_dir, val_loader=val_loader, patience=10) | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| eval_dataset=val_dataset, | |
| data_collator=collate_batch, | |
| callbacks=[callback], | |
| ) | |
| callback.trainer_ref = trainer | |
| start_time = time.time() | |
| trainer.train() | |
| total_time = time.time() - start_time | |
| best_model_path = os.path.join(best_model_dir, "pytorch_model.bin") | |
| if os.path.exists(best_model_path): | |
| try: | |
| model.load_state_dict(torch.load(best_model_path, map_location=device)) | |
| print(f"\nLoaded best model from {best_model_path}") | |
| except Exception as e: | |
| print(f"\nFailed to load best model from {best_model_path}: {e}") | |
| # Final evaluation | |
| model.eval() | |
| preds_z_all, true_z_all = [], [] | |
| pred_dists_all, true_dists_all = [], [] | |
| logits_masked_list_final, labels_masked_list_final = [], [] | |
| with torch.no_grad(): | |
| for batch in val_loader: | |
| z = batch["z"].to(device) | |
| chir = batch["chirality"].to(device) | |
| fc = batch["formal_charge"].to(device) | |
| edge_index = batch["edge_index"].to(device) | |
| edge_attr = batch["edge_attr"].to(device) | |
| batch_idx = batch["batch"].to(device) | |
| labels_z = batch["labels_z"].to(device) | |
| labels_dists = batch["labels_dists"].to(device) | |
| labels_dists_mask = batch["labels_dists_mask"].to(device) | |
| logits, dists_pred = model(z, chir, fc, edge_index, edge_attr, batch_idx) | |
| mask = labels_z != -100 | |
| if mask.sum().item() == 0: | |
| continue | |
| logits_masked_list_final.append(logits[mask]) | |
| labels_masked_list_final.append(labels_z[mask]) | |
| pred_z = torch.argmax(logits[mask], dim=-1) | |
| true_z = labels_z[mask] | |
| pred_d = dists_pred[mask][labels_dists_mask[mask]] | |
| true_d = labels_dists[mask][labels_dists_mask[mask]] | |
| if pred_d.numel() > 0: | |
| pred_dists_all.extend(pred_d.cpu().tolist()) | |
| true_dists_all.extend(true_d.cpu().tolist()) | |
| preds_z_all.extend(pred_z.cpu().tolist()) | |
| true_z_all.extend(true_z.cpu().tolist()) | |
| accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0 | |
| f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0 | |
| rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0 | |
| mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0 | |
| if len(logits_masked_list_final) > 0: | |
| all_logits_masked_final = torch.cat(logits_masked_list_final, dim=0) | |
| all_labels_masked_final = torch.cat(labels_masked_list_final, dim=0) | |
| cw_final = getattr(model, "class_weights", None) | |
| if cw_final is not None: | |
| try: | |
| loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final, weight=cw_final.to(device)) | |
| except Exception: | |
| loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final) | |
| else: | |
| loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final) | |
| try: | |
| perplexity_final = float(torch.exp(loss_z_final).cpu().item()) | |
| except Exception: | |
| perplexity_final = float(np.exp(float(loss_z_final.cpu().item()))) | |
| else: | |
| perplexity_final = float("nan") | |
| best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan") | |
| best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None | |
| print(f"\n=== Final Results (evaluated on best saved model) ===") | |
| print(f"Total Training Time (s): {total_time:.2f}") | |
| print(f"Best Epoch (1-based): {best_epoch_num}" if best_epoch_num is not None else "Best Epoch: (none saved)") | |
| print(f"Best Validation Loss: {best_val_loss:.4f}") | |
| print(f"Validation Accuracy: {accuracy:.4f}") | |
| print(f"Validation F1 (weighted): {f1:.4f}") | |
| print(f"Validation RMSE (distances): {rmse:.4f}") | |
| print(f"Validation MAE (distances): {mae:.4f}") | |
| print(f"Validation Perplexity (classification head): {perplexity_final:.4f}") | |
| total_params = sum(p.numel() for p in model.parameters()) | |
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| non_trainable_params = total_params - trainable_params | |
| print(f"Total Parameters: {total_params}") | |
| print(f"Trainable Parameters: {trainable_params}") | |
| print(f"Non-trainable Parameters: {non_trainable_params}") | |
| def main(): | |
| args = parse_args() | |
| train_and_evaluate(args) | |
| if __name__ == "__main__": | |
| main() | |