Spaces:
Running
Running
| """ | |
| Transformer.py | |
| Fingerprint masked language modeling (MLM) using a Transformer encoder. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import json | |
| import time | |
| import sys | |
| import csv | |
| import argparse | |
| from typing import List, Optional | |
# Raise the stdlib csv per-field size limit: fingerprint cells can be very long,
# and pd.read_csv(engine="python") parses through the csv module.
# sys.maxsize does not fit in a C long on platforms where that type is 32-bit
# (e.g. 64-bit Windows) and raises OverflowError there, so fall back to the
# largest 32-bit signed value in that case.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import TrainingArguments, Trainer | |
| from transformers.trainer_callback import TrainerCallback | |
| from sklearn.metrics import accuracy_score, f1_score | |
# ---------------------------
# Configuration / Constants
# ---------------------------

# --- MLM corruption ---
P_MASK = 0.15  # per-position probability of being selected as a prediction target

# --- Fingerprint input ---
FINGERPRINT_KEY = "morgan_r3_bits"  # key read from each JSON fingerprint record
FP_LENGTH = 2 ** 11  # 2048 bits per fingerprint; shorter inputs are zero-padded

# --- Token vocabulary: bit values 0 and 1, plus a dedicated mask token ---
MASK_TOKEN_ID = 2
VOCAB_SIZE = MASK_TOKEN_ID + 1  # 3

# --- Transformer encoder ---
HIDDEN_DIM = 256
TRANSFORMER_NUM_LAYERS = 4
TRANSFORMER_NHEAD = 8
TRANSFORMER_FF = 4 * HIDDEN_DIM  # 1024
DROPOUT = 0.1

# --- Optimisation ---
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
NUM_EPOCHS = 25
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
def parse_args() -> argparse.Namespace:
    """Parse command-line options for fingerprint MLM pretraining.

    Returns:
        Namespace with ``csv_path``, ``target_rows``, ``chunksize``,
        ``output_dir`` and ``num_workers``.
    """
    cli = argparse.ArgumentParser(description="Fingerprint MLM pretraining (Transformer).")
    # (flag, type, default, help) -- registered in this order so --help output is unchanged.
    option_table = (
        (
            "--csv_path",
            str,
            "/path/to/polymer_structures_unified_processed.csv",
            "Processed CSV containing a JSON 'fingerprints' column.",
        ),
        ("--target_rows", int, 5_000_000, "Max rows to parse."),
        ("--chunksize", int, 50_000, "CSV chunksize."),
        ("--output_dir", str, "/path/to/fingerprint_mlm_output_5M", "Training output directory."),
        ("--num_workers", int, 0, "PyTorch DataLoader num workers (kept default 0)."),
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()
# String spellings counted as a set bit (after stripping whitespace and quotes).
_TRUTHY_BIT_STRINGS = ("1", "True", "true")


def _normalize_bits(bits) -> List[int]:
    """Map an iterable of bit-like values to exactly FP_LENGTH ints in {0, 1}.

    Strings are 1 only if, after stripping whitespace and quote characters,
    they are one of ``_TRUTHY_BIT_STRINGS``; integers (including numpy
    integers and bools) are 1 when non-zero; anything else is 0. Longer
    inputs are truncated, shorter ones are zero-padded.
    """
    out: List[int] = []
    for b in bits:
        if len(out) >= FP_LENGTH:
            break
        if isinstance(b, str):
            out.append(1 if b.strip().strip('"').strip("'") in _TRUTHY_BIT_STRINGS else 0)
        elif isinstance(b, (int, np.integer)):
            out.append(1 if int(b) != 0 else 0)
        else:
            out.append(0)
    out.extend([0] * (FP_LENGTH - len(out)))
    return out


def _parse_fingerprint(fpval) -> List[int]:
    """Turn one raw ``fingerprints`` cell into a FP_LENGTH bit vector.

    Accepted forms: a JSON object holding the bits under FINGERPRINT_KEY, a
    JSON list of bits, the same with single quotes, or a bare comma-separated
    bit string. Missing/unparseable values become an all-zero vector.
    """
    if isinstance(fpval, str):
        try:
            record = json.loads(fpval)
        except ValueError:
            try:
                # Python-repr style payloads use single quotes.
                record = json.loads(fpval.replace("'", '"'))
            except ValueError:
                # Last resort: treat the cell as a comma-separated list of bits.
                return _normalize_bits(fpval.split(","))
    elif isinstance(fpval, (dict, list)):
        record = fpval
    else:
        # NaN / None / any other scalar: no usable fingerprint.
        return [0] * FP_LENGTH

    # Check the container type BEFORE calling .get(): a JSON list (or scalar)
    # has no .get and previously raised AttributeError here.
    if isinstance(record, dict):
        bits = record.get(FINGERPRINT_KEY)
    elif isinstance(record, list):
        bits = record
    else:
        bits = None
    if not isinstance(bits, (list, tuple, str)):
        return [0] * FP_LENGTH
    return _normalize_bits(bits)


def load_fingerprints(csv_path: str, target_rows: int, chunksize: int) -> List[List[int]]:
    """Stream the CSV and parse at most ``target_rows`` fingerprints.

    Args:
        csv_path: CSV with a ``fingerprints`` column (see ``_parse_fingerprint``
            for accepted cell formats).
        target_rows: Maximum number of rows to return; the last chunk is
            truncated so this limit is never exceeded.
        chunksize: Rows per chunk read from the CSV.

    Returns:
        One list of FP_LENGTH ints in {0, 1} per parsed row, in file order.
    """
    fp_lists: List[List[int]] = []
    if target_rows > 0:
        # Only the fingerprints column is needed; the context manager closes the
        # underlying file even when we stop early.
        with pd.read_csv(
            csv_path, engine="python", chunksize=chunksize, usecols=["fingerprints"]
        ) as reader:
            for chunk in reader:
                remaining = target_rows - len(fp_lists)
                fp_lists.extend(
                    _parse_fingerprint(fpval) for fpval in chunk["fingerprints"].iloc[:remaining]
                )
                if len(fp_lists) >= target_rows:
                    break
    print(f"Loaded {len(fp_lists)} fingerprint vectors (using FP_LENGTH={FP_LENGTH}).")
    return fp_lists
class FingerprintDataset(Dataset):
    """Map-style dataset over an in-memory sequence of fingerprint vectors.

    Items are handed back exactly as stored; no copying or conversion happens
    here.
    """

    def __init__(self, fps: List[torch.Tensor]):
        # Keep a reference to the caller's sequence as-is.
        self.fps = fps

    def __len__(self) -> int:
        """Number of stored fingerprints."""
        return len(self.fps)

    def __getitem__(self, idx):
        """Return the stored fingerprint at position ``idx``."""
        item = self.fps[idx]
        return item
def collate_batch(batch):
    """Stack fingerprints and apply masked-LM corruption.

    For every row, each position is independently selected with probability
    ``P_MASK``; if every position happens to be selected, one random position
    is deselected. Selected positions carry their true value in ``labels``,
    all others get -100. Among selected inputs, draws below 0.8 become
    ``MASK_TOKEN_ID``, draws in [0.8, 0.9) become a random bit, and the rest
    are left unchanged.

    Returns:
        dict with ``input_ids`` (corrupted, long), ``labels`` (long) and an
        all-True ``attention_mask`` (bool), each shaped (B, L). An empty batch
        yields (0, FP_LENGTH) tensors.
    """
    if len(batch) == 0:
        empty_shape = (0, FP_LENGTH)
        return {
            "input_ids": torch.zeros(empty_shape, dtype=torch.long),
            "labels": torch.zeros(empty_shape, dtype=torch.long),
            "attention_mask": torch.zeros(empty_shape, dtype=torch.bool),
        }

    as_tensors = [
        item if isinstance(item, torch.Tensor) else torch.tensor(item, dtype=torch.long)
        for item in batch
    ]
    clean = torch.stack(as_tensors, dim=0).long()
    labels = torch.full_like(clean, fill_value=-100, dtype=torch.long)
    corrupted = clean.clone()

    # Random draws happen in a fixed per-row order, so output is reproducible
    # under a fixed torch seed.
    for row, bits in enumerate(clean):
        length = bits.size(0)
        chosen = torch.rand(length) < P_MASK
        if chosen.all():
            chosen[torch.randint(0, length, (1,))] = False
        positions = chosen.nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            continue
        labels[row, positions] = bits[positions]
        draw = torch.rand(positions.size(0))
        to_mask = positions[draw < 0.8]
        to_randomize = positions[(draw >= 0.8) & (draw < 0.9)]
        corrupted[row, to_mask] = MASK_TOKEN_ID
        if to_randomize.numel() > 0:
            corrupted[row, to_randomize] = torch.randint(
                0, 2, (to_randomize.numel(),), dtype=torch.long
            )

    return {
        "input_ids": corrupted,
        "labels": labels,
        "attention_mask": torch.ones_like(clean, dtype=torch.bool),
    }
class FingerprintEncoder(nn.Module):
    """Transformer encoder over a length-FP_LENGTH token sequence with small vocab {0,1,MASK}.

    Each token gets a learned token embedding plus a learned absolute position
    embedding, then passes through ``num_layers`` batch-first
    ``nn.TransformerEncoderLayer`` blocks.
    """

    def __init__(
        self,
        vocab_size=VOCAB_SIZE,
        hidden_dim=HIDDEN_DIM,
        seq_len=FP_LENGTH,
        num_layers=TRANSFORMER_NUM_LAYERS,
        nhead=TRANSFORMER_NHEAD,
        dim_feedforward=TRANSFORMER_FF,
        dropout=DROPOUT,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, hidden_dim)
        # One learned vector per position, so inputs may be at most seq_len long.
        self.pos_emb = nn.Embedding(seq_len, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids.

        Args:
            input_ids: (B, L) integer token ids with L <= seq_len.
            attention_mask: optional (B, L) mask, bool or 0/1 integer; positions
                equal to 0/False are treated as padding and ignored as keys.

        Returns:
            (B, L, hidden_dim) contextual token representations.

        Raises:
            ValueError: if L exceeds the positional-embedding table size.
        """
        _, length = input_ids.shape
        max_len = self.pos_emb.num_embeddings
        if length > max_len:
            raise ValueError(
                f"input length {length} exceeds positional-embedding size {max_len}"
            )
        positions = torch.arange(length, device=input_ids.device)
        # (1, L, H) position embeddings broadcast over the batch dimension.
        x = self.token_emb(input_ids) + self.pos_emb(positions).unsqueeze(0)
        # `== 0` (rather than `~mask`) yields the boolean "is padding" mask for
        # both bool and integer masks; `~` on an integer tensor is bitwise NOT
        # and would produce -1/-2 values interpreted as an additive float mask.
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)
        return self.transformer(x, src_key_padding_mask=key_padding_mask)
# =============================================================================
# Wrapper used for MLM training
# =============================================================================
class PooledFingerprintEncoder(nn.Module):
    """
    Dual-use wrapper around FingerprintEncoder:
    - labels is None  -> pooled embedding of shape (B, emb_dim)
    - labels provided -> scalar masked-LM cross-entropy over positions whose
      label is not -100 [Trainer-compatible MLM]
    Also provides token_logits(...) returning per-position vocab logits.
    """

    def __init__(
        self,
        vocab_size=VOCAB_SIZE,
        hidden_dim=HIDDEN_DIM,
        seq_len=FP_LENGTH,
        num_layers=TRANSFORMER_NUM_LAYERS,
        nhead=TRANSFORMER_NHEAD,
        dim_feedforward=TRANSFORMER_FF,
        dropout=DROPOUT,
        emb_dim: int = 600,
    ):
        super().__init__()
        self.encoder = FingerprintEncoder(
            vocab_size=vocab_size,
            hidden_dim=hidden_dim,
            seq_len=seq_len,
            num_layers=num_layers,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.mlm_head = nn.Linear(hidden_dim, vocab_size)
        self.pool_proj = nn.Linear(hidden_dim, emb_dim)

    def _pool(self, h, attention_mask=None):
        """Mean over the sequence axis, restricted to positions where attention_mask is non-zero."""
        if attention_mask is None:
            return h.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).float()
        # clamp avoids division by zero for rows with no attended positions.
        denom = mask.sum(dim=1).clamp(min=1.0)
        return (h * mask).sum(dim=1) / denom

    def token_logits(self, input_ids, attention_mask=None):
        """Return (B, L, vocab_size) logits from the MLM head."""
        h = self.encoder(input_ids, attention_mask=attention_mask)
        return self.mlm_head(h)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Return the MLM loss when ``labels`` is given, else a pooled embedding."""
        # Single encoder pass shared by both modes (previously the pooled path
        # ran the encoder twice and discarded the first result).
        h = self.encoder(input_ids, attention_mask=attention_mask)
        if labels is None:
            # pooled embedding for CL
            return self.pool_proj(self._pool(h, attention_mask=attention_mask))

        logits = self.mlm_head(h)
        selected = labels != -100
        if not selected.any():
            # Zero loss that stays attached to the graph, so backward() still
            # works; a fresh constant tensor would have no grad_fn.
            return (logits.float() * 0.0).sum()
        return F.cross_entropy(logits[selected], labels[selected].long())
class ValLossCallback(TrainerCallback):
    """Tracks best eval loss, prints metrics, saves best model, early-stops.

    On each evaluation the callback runs its own pass over ``val_loader`` to
    report masked-token accuracy, weighted F1 and perplexity. Model selection
    uses the Trainer's ``eval_loss`` when present, otherwise the mean per-batch
    masked cross-entropy from that pass. An improvement saves the model's
    ``state_dict`` to ``<best_model_dir>/pytorch_model.bin``; ``patience``
    consecutive evaluations without improvement request a training stop.
    """

    def __init__(self, best_model_dir: str, val_loader: DataLoader, patience: int = 10, trainer_ref=None):
        self.best_val_loss = float("inf")
        self.epochs_no_improve = 0
        self.patience = patience
        # Number of completed epochs at the best evaluation (1 after the first epoch).
        self.best_epoch = None
        # Object exposing `.model`; may be assigned after construction.
        self.trainer_ref = trainer_ref
        self.best_model_dir = best_model_dir
        self.val_loader = val_loader

    @staticmethod
    def _completed_epochs(state) -> int:
        """Return the completed-epoch count implied by ``state.epoch``.

        ``state.epoch`` is a float progress counter that already equals the
        number of finished epochs at an epoch-end evaluation; rounding guards
        against values like 0.9999 at the boundary.
        """
        return int(round(state.epoch or 0))

    def on_epoch_end(self, args, state, control, **kwargs):
        epoch_num = self._completed_epochs(state)
        train_loss = next((x["loss"] for x in reversed(state.log_history) if "loss" in x), None)
        print(f"\n=== Epoch {epoch_num}/{args.num_train_epochs} ===")
        if train_loss is not None:
            print(f"Train Loss: {train_loss:.4f}")

    def _masked_metrics(self, model, device):
        """Stream masked-token metrics over ``self.val_loader``.

        Only running sums and an (n_classes x n_classes) count table are kept,
        so memory does not grow with the validation-set size.

        Returns:
            (mean per-batch masked CE or nan, accuracy, weighted F1, perplexity);
            accuracy/F1 are 0.0 and perplexity is nan when nothing was masked.
        """
        batch_loss_sum, n_batches = 0.0, 0
        ce_sum, n_tokens = 0.0, 0
        confusion, n_classes = None, 0
        with torch.no_grad():
            for batch in self.val_loader:
                input_ids = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                attention_mask = batch.get("attention_mask", torch.ones_like(input_ids, dtype=torch.bool)).to(device)
                # One forward pass serves both the loss and the predictions.
                logits = model.token_logits(input_ids=input_ids, attention_mask=attention_mask)
                selected = labels != -100
                if not selected.any():
                    continue
                sel_logits = logits[selected].float()
                sel_labels = labels[selected].long()
                batch_ce = F.cross_entropy(sel_logits, sel_labels, reduction="sum").item()
                count = sel_labels.numel()
                batch_loss_sum += batch_ce / count
                n_batches += 1
                ce_sum += batch_ce
                n_tokens += count
                n_classes = sel_logits.size(-1)
                preds = sel_logits.argmax(dim=-1)
                # Flattened (true, pred) pair counts.
                cells = torch.bincount(sel_labels * n_classes + preds, minlength=n_classes * n_classes)
                confusion = cells if confusion is None else confusion + cells

        mean_loss = batch_loss_sum / n_batches if n_batches > 0 else float("nan")
        if confusion is None:
            return mean_loss, 0.0, 0.0, float("nan")
        # Feed sklearn one row per observed (true, pred) pair weighted by its
        # count: identical results to passing every token individually.
        confusion = confusion.cpu()
        filled = torch.nonzero(confusion).squeeze(-1)
        y_true = (filled // n_classes).numpy()
        y_pred = (filled % n_classes).numpy()
        weights = confusion[filled].numpy()
        accuracy = accuracy_score(y_true, y_pred, sample_weight=weights)
        f1 = f1_score(y_true, y_pred, average="weighted", sample_weight=weights)
        perplexity = float(np.exp(ce_sum / n_tokens))
        return mean_loss, accuracy, f1, perplexity

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        epoch_num = self._completed_epochs(state)
        if self.trainer_ref is None:
            print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
            return
        metric_val_loss = metrics.get("eval_loss") if metrics is not None else None
        model_eval = self.trainer_ref.model
        was_training = model_eval.training
        model_eval.eval()
        device_local = next(model_eval.parameters()).device
        try:
            own_loss, accuracy, f1, perplexity = self._masked_metrics(model_eval, device_local)
        finally:
            if was_training:
                model_eval.train()
        avg_val_loss = metric_val_loss if metric_val_loss is not None else own_loss

        print(f"\n--- Evaluation after Epoch {epoch_num} ---")
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy: {accuracy:.4f}")
        print(f"Validation F1 (weighted): {f1:.4f}")
        print(f"Validation Perplexity (classification head): {perplexity:.4f}")

        # A NaN loss compares False here and therefore counts as "no improvement".
        if avg_val_loss is not None and avg_val_loss < self.best_val_loss - 1e-6:
            self.best_val_loss = avg_val_loss
            self.best_epoch = epoch_num
            self.epochs_no_improve = 0
            os.makedirs(self.best_model_dir, exist_ok=True)
            save_path = os.path.join(self.best_model_dir, "pytorch_model.bin")
            try:
                torch.save(self.trainer_ref.model.state_dict(), save_path)
                print(f"Saved new best model (epoch {epoch_num}) to {save_path}")
            except (OSError, RuntimeError) as e:
                # Best-effort checkpointing: report and keep training.
                print(f"Failed to save best model at epoch {epoch_num}: {e}")
        else:
            self.epochs_no_improve += 1
            if self.epochs_no_improve >= self.patience:
                print(f"Early stopping after {self.patience} epochs with no improvement.")
                control.should_training_stop = True
def _masked_token_metrics(model, loader, device):
    """Compute masked-token accuracy, weighted F1 and perplexity over ``loader``.

    Streams running sums and an (n_classes x n_classes) count table instead of
    keeping every masked logit, so memory does not grow with the dataset.

    Returns:
        (accuracy, weighted F1, perplexity); (0.0, 0.0, nan) when nothing was masked.
    """
    model.eval()
    ce_sum, n_tokens = 0.0, 0
    confusion, n_classes = None, 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            attention_mask = batch.get("attention_mask", torch.ones_like(input_ids, dtype=torch.bool)).to(device)
            logits = model.token_logits(input_ids=input_ids, attention_mask=attention_mask)
            selected = labels != -100
            if not selected.any():
                continue
            sel_logits = logits[selected].float()
            sel_labels = labels[selected].long()
            ce_sum += F.cross_entropy(sel_logits, sel_labels, reduction="sum").item()
            n_tokens += sel_labels.numel()
            n_classes = sel_logits.size(-1)
            preds = sel_logits.argmax(dim=-1)
            cells = torch.bincount(sel_labels * n_classes + preds, minlength=n_classes * n_classes)
            confusion = cells if confusion is None else confusion + cells
    if confusion is None:
        return 0.0, 0.0, float("nan")
    # One sklearn row per observed (true, pred) pair, weighted by its count.
    confusion = confusion.cpu()
    filled = torch.nonzero(confusion).squeeze(-1)
    y_true = (filled // n_classes).numpy()
    y_pred = (filled % n_classes).numpy()
    weights = confusion[filled].numpy()
    accuracy = accuracy_score(y_true, y_pred, sample_weight=weights)
    f1 = f1_score(y_true, y_pred, average="weighted", sample_weight=weights)
    perplexity = float(np.exp(ce_sum / n_tokens))
    return accuracy, f1, perplexity


def _build_trainer(model, training_args, train_dataset, eval_dataset, callback):
    """Build a Trainer that accepts a model returning a bare scalar loss tensor.

    When ``labels`` are passed, the model's forward returns the loss itself
    (a 0-dim tensor). The stock ``Trainer.compute_loss`` takes ``outputs[0]``
    for non-dict outputs, which fails on a 0-dim tensor, so the loss is
    wrapped in a dict here.
    """

    class _ScalarLossTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extra arguments newer Trainer versions pass
            # (e.g. num_items_in_batch).
            outputs = model(**inputs)
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs
            return (loss, {"loss": loss}) if return_outputs else loss

    return _ScalarLossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_batch,
        callbacks=[callback],
    )


def train_and_eval(args: argparse.Namespace) -> None:
    """Pretrain the fingerprint MLM encoder and report validation metrics.

    Steps: load fingerprints, split 80/20 (random_state=42), train with the HF
    Trainer (per-epoch evaluation; ValLossCallback handles best-checkpointing
    and early stopping), reload the best checkpoint if one was saved, then
    print final masked-token metrics on the validation split and parameter
    counts.
    """
    output_dir = args.output_dir
    best_model_dir = os.path.join(output_dir, "best")
    os.makedirs(output_dir, exist_ok=True)

    fp_lists = load_fingerprints(args.csv_path, args.target_rows, args.chunksize)
    train_idx, val_idx = train_test_split(list(range(len(fp_lists))), test_size=0.2, random_state=42)
    train_dataset = FingerprintDataset([torch.tensor(fp_lists[i], dtype=torch.long) for i in train_idx])
    val_dataset = FingerprintDataset([torch.tensor(fp_lists[i], dtype=torch.long) for i in val_idx])
    # The nested Python lists are no longer needed once tensors exist.
    del fp_lists

    # The Trainer builds its own training loader from train_dataset; this
    # loader is only for the callback's and the final evaluation.
    val_loader = DataLoader(
        val_dataset,
        batch_size=EVAL_BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_batch,
        drop_last=False,
        num_workers=args.num_workers,
    )

    model = PooledFingerprintEncoder(
        vocab_size=VOCAB_SIZE,
        hidden_dim=HIDDEN_DIM,
        seq_len=FP_LENGTH,
        num_layers=TRANSFORMER_NUM_LAYERS,
        nhead=TRANSFORMER_NHEAD,
        dim_feedforward=TRANSFORMER_FF,
        dropout=DROPOUT,
        emb_dim=600,
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=EVAL_BATCH_SIZE,
        eval_accumulation_steps=1000,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        eval_strategy="epoch",
        logging_steps=500,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        fp16=torch.cuda.is_available(),
        save_strategy="no",
        disable_tqdm=False,
        logging_first_step=True,
        report_to=[],
        dataloader_num_workers=args.num_workers,
        # Only eval_loss is consumed; avoid gathering (N, FP_LENGTH) labels
        # and empty predictions on every evaluation.
        prediction_loss_only=True,
    )
    callback = ValLossCallback(best_model_dir=best_model_dir, val_loader=val_loader, patience=10)
    trainer = _build_trainer(model, training_args, train_dataset, val_dataset, callback)
    callback.trainer_ref = trainer

    start_time = time.time()
    trainer.train()
    total_time = time.time() - start_time

    best_model_path = os.path.join(best_model_dir, "pytorch_model.bin")
    if os.path.exists(best_model_path):
        try:
            model.load_state_dict(torch.load(best_model_path, map_location=device))
            print(f"\nLoaded best model from {best_model_path}")
        except Exception as e:
            # Best-effort: fall back to the final in-memory weights.
            print(f"\nFailed to load best model from {best_model_path}: {e}")

    # Final evaluation
    accuracy, f1, perplexity_final = _masked_token_metrics(model, val_loader, device)

    best_val_loss = callback.best_val_loss
    # ValLossCallback stores the completed-epoch count (int(state.epoch), which
    # is already 1 after the first epoch), so no +1 is needed here.
    best_epoch_num = int(callback.best_epoch) if callback.best_epoch is not None else None
    print("\n=== Final Results (evaluated on best saved model) ===")
    print(f"Total Training Time (s): {total_time:.2f}")
    print(f"Best Epoch (1-based): {best_epoch_num}" if best_epoch_num is not None else "Best Epoch: (none saved)")
    print(f"Best Validation Loss: {best_val_loss:.4f}")
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(f"Validation F1 (weighted): {f1:.4f}")
    print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {trainable_params}")
    print(f"Non-trainable Parameters: {non_trainable_params}")
def main():
    """Command-line entry point: parse options, then train and evaluate."""
    train_and_eval(parse_args())


if __name__ == "__main__":
    main()